compile_hip_code_object.cpp 7.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
24
25
26
27
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
28
#include <migraphx/gpu/device_name.hpp>
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// Renders the elements of `v` as the source text of an `index_ints<...>{}`
// expression. The element list itself is formatted by to_string_range().
template <class T>
std::string generate_index_ints(const std::vector<T>& v)
{
    std::string text = "index_ints<";
    text += to_string_range(v);
    text += ">{}";
    return text;
}

// Builds the source text of a `make_shape(<lens>, <strides>)` call from the
// lens and strides of `s`, each rendered through generate_index_ints().
std::string generate_make_shape(const shape& s)
{
    const std::string lens_text    = generate_index_ints(s.lens());
    const std::string strides_text = generate_index_ints(s.strides());
    return "make_shape(" + lens_text + ", " + strides_text + ")";
}

// Source template for one `make_tensor<N>` specialization in the generated
// args header. The placeholders ${n}, ${type}, ${lens} and ${strides} are
// substituted by generate_make_tensor(). Everything between the raw-string
// delimiters is emitted verbatim, so the literal must contain only kernel
// source text.
static const char* const make_tensor_template = R"__migraphx__(
template<>
struct make_tensor<${n}>
{
    static __device__ auto apply(void* __restrict__ p)
    {
        return make_tensor_view(reinterpret_cast<${type}* __restrict__>(p), make_shape(${lens}, ${strides}));
    }
};
)__migraphx__";

std::string generate_make_tensor(std::size_t n, const shape& s)
{
    return interpolate_string(make_tensor_template,
                              {{"n", std::to_string(n)},
Paul Fultz II's avatar
Paul Fultz II committed
65
                               {"type", shape::cpp_type(s.type())},
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
                               {"lens", generate_index_ints(s.lens())},
                               {"strides", generate_index_ints(s.strides())}});
}

std::string generate_args_hpp(const std::vector<shape>& inputs)
{
    std::string inner;
    for(std::size_t i = 0; i < inputs.size(); i++)
    {
        inner += generate_make_tensor(i, inputs[i]);
    }
    const std::string args_hpp = R"__migraphx__(
#ifndef MIGRAPHX_GUARD_AUTO_ARGS_HPP
#define MIGRAPHX_GUARD_AUTO_ARGS_HPP

#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tensor_view.hpp>

namespace migraphx {

__content__

} // namespace migraphx
#endif
)__migraphx__";
    return replace_string(args_hpp, "__content__", inner);
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Returns the warning flags used when compiling generated kernels:
// -Weverything followed by the individual -Wno-* opt-outs. The list is built
// once on first call and the same object is returned on every call.
const std::vector<std::string>& compiler_warnings()
{
    static std::vector<std::string> flags{
        "-Weverything",
        "-Wno-c++98-compat",
        "-Wno-c++98-compat-pedantic",
        "-Wno-conversion",
        "-Wno-double-promotion",
        "-Wno-exit-time-destructors",
        "-Wno-extra-semi",
        "-Wno-extra-semi-stmt",
        "-Wno-float-conversion",
        "-Wno-gnu-anonymous-struct",
        "-Wno-gnu-zero-variadic-macro-arguments",
        "-Wno-missing-prototypes",
        "-Wno-nested-anon-types",
        "-Wno-padded",
        "-Wno-shorten-64-to-32",
        "-Wno-sign-conversion",
        "-Wno-sign-compare",
        "-Wno-unused-command-line-argument",
        "-Wno-weak-vtables",
        "-Wno-c99-extensions",
    };
    return flags;
}

119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Sets `local` and `global` from the value `v`.
//
// `local` is read from v["local"], using `default_local` when that key is
// absent. `global` is read from v["global"] when present; otherwise it is
// obtained by calling `compute_global` with the chosen local size.
void hip_compile_options::set_launch_params(
    const value& v,
    const std::function<std::size_t(std::size_t local)>& compute_global,
    std::size_t default_local)
{
    // local must be assigned first: the fallback for global depends on it.
    local  = v.get("local", default_local);
    global = v.contains("global") ? v.at("global").to<std::size_t>() : compute_global(local);
}

// Returns a function mapping a local (work-group) size to a global size for
// `n` elements.
//
// The number of groups is ceil(n / local), capped at
// (device_limit / local) * over, where device_limit is the current device's
// cu count times its max workitems per cu. The result is that group count
// multiplied by `local`. `over` must be positive (checked by assert).
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
    assert(over > 0);
    std::size_t device_limit = ctx.get_current_device().get_cu_count() *
                               ctx.get_current_device().get_max_workitems_per_cu();
    return [n, over, device_limit](std::size_t local) {
        const std::size_t wanted_groups  = (n + local - 1) / local;
        const std::size_t allowed_groups = (device_limit / local) * over;
        return std::min(allowed_groups, wanted_groups) * local;
    };
}

Paul Fultz II's avatar
Paul Fultz II committed
145
146
147
148
149
150
151
152
// Picks a power-of-two block size for `n` elements.
//
// Returns the largest value in {128, 256, 512, ...} that is <= both `n` and
// `max_block_size`; when even 128 exceeds either bound, returns 64.
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size)
{
    std::size_t chosen = 64;
    for(std::size_t candidate = 128; candidate <= max_block_size and candidate <= n;
        candidate *= 2)
    {
        chosen = candidate;
    }
    return chosen;
}

153
154
// Compiles the kernel source `content` into a code_object_op.
//
// The compilation unit consists of:
//   - every embedded kernel header from migraphx_kernels(), placed under
//     migraphx/kernels/<name>;
//   - `content` as main.cpp;
//   - a generated args.hpp describing options.virtual_inputs when that list
//     is non-empty, otherwise options.inputs.
//
// options.params is extended with -DMIGRAPHX_NGLOBAL / -DMIGRAPHX_NLOCAL,
// the flags from compiler_warnings(), -ftemplate-backtrace-limit=0 and
// -Werror before being handed to compile_hip_src() together with
// get_device_name().
//
// Preconditions (assert-checked): global and local are positive, inputs is
// non-empty, and virtual_inputs is either empty or the same size as inputs.
//
// Throws (MIGRAPHX_THROW) unless compile_hip_src() yields exactly one code
// object.
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
    assert(options.global > 0);
    assert(options.local > 0);
    assert(not options.inputs.empty());
    assert(options.inputs.size() == options.virtual_inputs.size() or
           options.virtual_inputs.empty());
    // Bind the kernel table once so begin() and end() come from the same
    // object whether migraphx_kernels() returns a reference or a value.
    const auto& kernels = migraphx_kernels();
    std::vector<src_file> srcs;
    std::transform(kernels.begin(), kernels.end(), std::back_inserter(srcs), [](auto&& p) {
        auto&& name = p.first;
        auto&& c    = p.second;
        auto path   = fs::path{"migraphx"} / "kernels" / name;
        return src_file{path, c};
    });
    srcs.push_back(src_file{fs::path{"main.cpp"},
                            std::make_pair(content.data(), content.data() + content.size())});
    // args_hpp must stay alive until compile_hip_src() returns: the src_file
    // below only stores pointers into its buffer.
    auto args_hpp =
        generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
    srcs.push_back(src_file{fs::path{"args.hpp"},
                            std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
    options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
    options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
    options.params += " " + join_strings(compiler_warnings(), " ");
    options.params += " -ftemplate-backtrace-limit=0";
    options.params += " -Werror";
    auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
    if(cos.size() != 1)
        MIGRAPHX_THROW("No code object");
    return code_object_op{value::binary{cos.front()},
                          options.kernel_name,
                          options.global,
                          options.local,
                          options.inputs,
                          options.output};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx