#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
#include <algorithm>
#include <cassert>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// Render `v` as device-side source text for a compile-time index list,
// e.g. {2, 3} becomes "index_ints<2, 3>{}" (element formatting is whatever
// to_string_range produces).
template <class T>
std::string generate_index_ints(const std::vector<T>& v)
{
    std::string code = "index_ints<";
    code += to_string_range(v);
    code += ">{}";
    return code;
}

// Render a `make_shape(lens, strides)` expression for `s`, with both
// dimensions lists emitted as compile-time index_ints.
std::string generate_make_shape(const shape& s)
{
    const std::string lens_code    = generate_index_ints(s.lens());
    const std::string strides_code = generate_index_ints(s.strides());
    return "make_shape(" + lens_code + ", " + strides_code + ")";
}

// Source text for one `make_tensor<N>` specialization in the generated
// args.hpp. The ${n}, ${type}, ${lens} and ${strides} placeholders are filled
// in by generate_make_tensor through interpolate_string.
static constexpr const char* make_tensor_template = R"__migraphx__(
template<>
struct make_tensor<${n}>
{
    static __device__ auto apply(void* p)
    {
        return make_tensor_view(reinterpret_cast<${type}*>(p), make_shape(${lens}, ${strides}));
    }
};
)__migraphx__";

std::string generate_make_tensor(std::size_t n, const shape& s)
{
    return interpolate_string(make_tensor_template,
                              {{"n", std::to_string(n)},
Paul Fultz II's avatar
Paul Fultz II committed
42
                               {"type", shape::cpp_type(s.type())},
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
                               {"lens", generate_index_ints(s.lens())},
                               {"strides", generate_index_ints(s.strides())}});
}

std::string generate_args_hpp(const std::vector<shape>& inputs)
{
    std::string inner;
    for(std::size_t i = 0; i < inputs.size(); i++)
    {
        inner += generate_make_tensor(i, inputs[i]);
    }
    const std::string args_hpp = R"__migraphx__(
#ifndef MIGRAPHX_GUARD_AUTO_ARGS_HPP
#define MIGRAPHX_GUARD_AUTO_ARGS_HPP

#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>

namespace migraphx {

__content__

} // namespace migraphx
#endif
)__migraphx__";
    return replace_string(args_hpp, "__content__", inner);
}

// Clang warning flags added to every HIP kernel compilation. The list enables
// -Weverything and then switches off selected diagnostics. It is built once,
// on the first call, and the same vector is returned by reference afterwards.
const std::vector<std::string>& compiler_warnings()
{
    static const std::vector<std::string> flags{"-Weverything",
                                                "-Wno-c++98-compat",
                                                "-Wno-c++98-compat-pedantic",
                                                "-Wno-conversion",
                                                "-Wno-double-promotion",
                                                "-Wno-exit-time-destructors",
                                                "-Wno-extra-semi",
                                                "-Wno-extra-semi-stmt",
                                                "-Wno-float-conversion",
                                                "-Wno-gnu-anonymous-struct",
                                                "-Wno-gnu-zero-variadic-macro-arguments",
                                                "-Wno-missing-prototypes",
                                                "-Wno-nested-anon-types",
                                                "-Wno-padded",
                                                "-Wno-shorten-64-to-32",
                                                "-Wno-sign-conversion",
                                                "-Wno-sign-compare",
                                                "-Wno-unused-command-line-argument",
                                                "-Wno-weak-vtables",
                                                "-Wno-c99-extensions"};
    return flags;
}

void hip_compile_options::set_launch_params(
    const value& v,
    const std::function<std::size_t(std::size_t local)>& compute_global,
    std::size_t default_local)
{
    // A "local" entry in v overrides default_local.
    local = v.get("local", default_local);
    // A "global" entry in v overrides the computed size. compute_global is
    // given the already-resolved local size.
    global = v.contains("global") ? v.at("global").to<std::size_t>() : compute_global(local);
}

// Build a callback that maps a workgroup size (`local`) to a global
// work-item count for a kernel covering `n` elements.
//
// The number of workgroups is ceil(n / local). It is capped at `over` times
// the number of workgroups the current device holds at once
// (cu_count * max_workitems_per_cu / local). The result is always a multiple
// of `local` and never smaller than one workgroup.
//
// @param ctx   GPU context; only the current device's limits are read.
// @param n     Number of elements the kernel must cover.
// @param over  Oversubscription factor applied to device capacity; must be > 0.
// @return      Callable taking `local` (must be > 0) and returning the global size.
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
    assert(over > 0);
    std::size_t max_global = ctx.get_current_device().get_cu_count() *
                             ctx.get_current_device().get_max_workitems_per_cu();
    return [n, over, max_global](std::size_t local) {
        assert(local > 0);
        std::size_t groups     = (n + local - 1) / local;
        std::size_t max_blocks = max_global / local;
        std::size_t nglobal    = std::min(max_blocks * over, groups) * local;
        // If n == 0, or local exceeds the device capacity (max_blocks == 0),
        // the product above is 0. A zero-sized grid cannot be launched, so
        // always request at least one workgroup.
        return std::max(nglobal, local);
    };
}

// Compile `content` (the kernel's main.cpp) together with the embedded
// migraphx kernel headers and a generated args.hpp. The result is wrapped in
// a code_object_op.
//
// @param content  Source text compiled as "main.cpp".
// @param options  Launch sizes, inputs/output shapes, kernel name and extra
//                 compiler params. Taken by value because params is extended
//                 and then moved into the compiler call.
// @return         A code_object_op holding the single compiled code object.
// @throws         Via MIGRAPHX_THROW if compilation does not yield exactly one
//                 code object.
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
    // Bind the embedded header table once so begin()/end() come from the same
    // object. A const& is safe whether migraphx_kernels() returns a reference
    // or a temporary (which would be lifetime-extended).
    const auto& kernels = migraphx_kernels();

    std::vector<src_file> srcs;
    std::transform(kernels.begin(), kernels.end(), std::back_inserter(srcs), [](auto&& p) {
        auto&& name = p.first;
        auto&& c    = p.second;
        auto path   = fs::path{"migraphx"} / "kernels" / name;
        return src_file{path, c};
    });
    srcs.push_back(src_file{fs::path{"main.cpp"},
                            std::make_pair(content.data(), content.data() + content.size())});

    // The src_file below is built from raw pointers into args_hpp, so the
    // string must stay alive until compile_hip_src returns.
    auto args_hpp =
        generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
    srcs.push_back(src_file{fs::path{"args.hpp"},
                            std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});

    options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
    options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
    options.params += " " + join_strings(compiler_warnings(), " ");
    options.params += " -ftemplate-backtrace-limit=0";
    options.params += " -Werror";

    auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
    if(cos.size() != 1)
        MIGRAPHX_THROW("No code object");
    return code_object_op{value::binary{cos.front()},
                          options.kernel_name,
                          options.global,
                          options.local,
                          options.inputs,
                          options.output};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx