ck_gemm.cpp 12.2 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
Paul's avatar
Paul committed
32
#include <migraphx/gpu/compile_gen.hpp>
Paul's avatar
Paul committed
33
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
34
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
35
36
37
38
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
39
#include <migraphx/file_buffer.hpp>
Paul's avatar
Paul committed
40

Paul's avatar
Paul committed
41
42
// Defined outside this file. ck_gemm_compiler::compile_op calls it with a
// tuning index and a predicate over an instance's template-argument strings,
// and wraps the returned strings in `instance`.
// NOTE(review): presumably returns the i-th CK instance for which pred holds —
// confirm against the definition.
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
Paul's avatar
Paul committed
43

Paul's avatar
Paul committed
44
45
46
47
48
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace gpu {

Paul's avatar
Paul committed
49
50
using namespace migraphx::gpu::gen; // NOLINT

Paul's avatar
Paul committed
51
// When enabled, ck_gemm_compiler::compile prints the first two input shapes
// and the last input shape of each compiled ck_gemm as JSON.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
// Path of the JSON tuning file read by read_tuning (list of {shapes, index}).
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
// Overrides the fallback instance index get_tuning_for picks when the tuning
// file has no exact match for the requested shapes.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
// When enabled, kernels are compiled with -DMIGRAPHX_CK_CHECK=1.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
Paul's avatar
Paul committed
55

Paul's avatar
Paul committed
56
57
58
59
// HIP source template for the CK GEMM kernel. The ${...} placeholders
// (preamble, kernel, params, args, instance, blocks_per_batch) are substituted
// with interpolate_string in ck_gemm_compiler::compile_op.
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>

namespace migraphx {

${preamble}

extern "C" {

__global__ void ${kernel}(${params})
{
    transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
        ck_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

Paul's avatar
Paul committed
81
82
// Ceiling division for unsigned sizes: the smallest q with q * y >= x.
// Computed as quotient plus a remainder check so it cannot wrap around when x
// is close to SIZE_MAX (the naive (x + y - 1) / y overflows there).
// y must be non-zero.
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    return x / y + (x % y == 0 ? 0 : 1);
}

Paul's avatar
Paul committed
83
struct instance
Paul's avatar
Paul committed
84
{
Paul's avatar
Paul committed
85
86
    std::vector<std::string> params;
    static const std::size_t block_size_index = 15;
Paul's avatar
Paul committed
87

Paul's avatar
Format  
Paul committed
88
    std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
Paul's avatar
Paul committed
89

Paul's avatar
Format  
Paul committed
90
    std::size_t get_block_size() const { return int_at(block_size_index); }
Paul's avatar
Paul committed
91
92
93
94
95
96
97
98
99
100

    std::size_t get_pb(std::size_t i) const
    {
        assert(i < 4);
        return int_at(block_size_index + 1 + i);
    }

    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
    {
        std::array<std::size_t, 3> result{};
Paul's avatar
Format  
Paul committed
101
        for(auto i : range(config.size()))
Paul's avatar
Paul committed
102
103
104
105
106
107
108
109
        {
            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
        }
        return result;
    }

    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
    {
Paul's avatar
Paul committed
110
        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
Paul's avatar
Paul committed
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    }

    void set_ds_layout(const std::string& s)
    {
        assert(params[2] == "ck::Tuple<>");
        params[2] = s;
    }

    void set_ds_type(const std::string& s)
    {
        assert(params[8] == "ck::Tuple<>");
        params[8] = s;
    }

    void set_ds_op(const std::string& s)
    {
        assert(params[12] == "ck_passthrough");
        params[12] = s;
    }

    void set_gemm(const std::string& s)
    {
        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
        params[13] = s;
    }

Paul's avatar
Format  
Paul committed
137
    std::string str() const { return join_strings(params, ","); }
Paul's avatar
Paul committed
138
};
Paul's avatar
Paul committed
139

Paul's avatar
Paul committed
140
141
// True when the innermost dimension of `s` is not unit-stride, i.e. its
// elements are not contiguous along the last axis. get_layout reports such
// operands to CK as ColumnMajor.
static bool transposed_matrix(const shape& s)
{
    const auto inner_stride = s.strides().back();
    return inner_stride != 1;
}

Paul's avatar
Format  
Paul committed
142
// Wrap callable `f` so that every invocation first runs `action()` and then
// forwards all arguments to `f`. Both callables are stored by value, and the
// result of `f` is discarded.
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [f, action](auto&&... args) {
        action();
        f(std::forward<decltype(args)>(args)...);
    };
}

Paul's avatar
Paul committed
151
152
153
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
Paul's avatar
Format  
Paul committed
154
    if(not fs::exists(s))
Paul's avatar
Paul committed
155
156
157
158
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}

Paul's avatar
Paul committed
159
160
static float matrix_distance(const shape& x, const shape& y)
{
Paul's avatar
Format  
Paul committed
161
    if(x.type() != y.type())
Paul's avatar
Paul committed
162
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
163
    if(transposed_matrix(x) != transposed_matrix(y))
Paul's avatar
Paul committed
164
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
165
166
167
168
169
170
    auto sum_squared = std::inner_product(x.lens().rbegin(),
                                          x.lens().rbegin() + 2,
                                          y.lens().rbegin(),
                                          0,
                                          std::plus<>{},
                                          [](auto a, auto b) { return (a - b) * (a - b); });
Paul's avatar
Paul committed
171
172
173
    return std::sqrt(sum_squared);
}

Paul's avatar
Paul committed
174
175
176
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
Paul's avatar
Format  
Paul committed
177
    if(tuning.empty())
Paul's avatar
Paul committed
178
        std::cout << "*********** Warning: No CK tuning!" << std::endl;
Paul's avatar
Format  
Paul committed
179
    auto it = std::find_if(
Paul's avatar
Format  
Paul committed
180
        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
Paul's avatar
Format  
Paul committed
181
182
    if(it == tuning.end())
    {
Paul's avatar
Paul committed
183
        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
Paul's avatar
Paul committed
184
185
        std::vector<std::pair<float, std::size_t>> w;
        std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
Paul's avatar
Format  
Paul committed
186
            if(inputs.size() < 3 or p.first.size() < 3)
Paul's avatar
Paul committed
187
                MIGRAPHX_THROW("Invalid CK config");
Paul's avatar
Format  
Paul committed
188
189
190
191
192
193
194
            auto avg_distance = std::inner_product(
                p.first.begin(),
                p.first.begin() + 3,
                inputs.begin(),
                0.0f,
                std::plus<>{},
                [](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
Paul's avatar
Paul committed
195
196
197
198
            return std::make_pair(avg_distance, p.second);
        });
        std::sort(w.begin(), w.end());
        std::size_t default_value = 4;
Paul's avatar
Format  
Paul committed
199
        if(not w.empty())
Paul's avatar
Paul committed
200
            default_value = w.front().second;
Paul's avatar
Paul committed
201
202
203
        auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
        std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
        return tuning_val;
Paul's avatar
Paul committed
204
    }
Paul's avatar
Paul committed
205
206
207
    return it->second;
}

Paul's avatar
Paul committed
208
209
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
Paul's avatar
Paul committed
210
211
    static std::string get_layout(const shape& s)
    {
Paul's avatar
Paul committed
212
        return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
Paul's avatar
Format  
Paul committed
213
                                    : "ck::tensor_layout::gemm::RowMajor";
Paul's avatar
Paul committed
214
215
216
    }

    static std::string get_type(const shape& s)
Paul's avatar
Paul committed
217
    {
Paul's avatar
Format  
Paul committed
218
        if(s.type() == shape::half_type)
Paul's avatar
Paul committed
219
220
221
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }
Paul's avatar
Paul committed
222

Paul's avatar
Format  
Paul committed
223
    template <class Iterator, class F>
Paul's avatar
Paul committed
224
225
226
227
228
229
230
    static std::string ck_tuple(Iterator start, Iterator last, F f)
    {
        std::vector<std::string> s;
        std::transform(start, last, std::back_inserter(s), f);
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

Paul's avatar
Paul committed
231
232
    static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
    {
Paul's avatar
Format  
Paul committed
233
        swap_inputs  = false;
Paul's avatar
Paul committed
234
        auto c_shape = inputs.back();
Paul's avatar
Format  
Paul committed
235
        if(not transposed_matrix(c_shape))
Paul's avatar
Paul committed
236
237
238
239
240
241
242
243
244
245
246
            return inputs;
        std::vector<int64_t> perm(c_shape.lens().size());
        std::iota(perm.begin(), perm.end(), 0);
        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
            return reorder_shape(s, perm);
        });
        swap_inputs = true;
        return inputs;
    }

Paul's avatar
Paul committed
247
248
249
250
    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
Paul's avatar
Paul committed
251
252
        auto a_shape = inputs[0];
        auto b_shape = inputs[1];
Paul's avatar
Paul committed
253
        auto c_shape = inputs.back();
Paul's avatar
Paul committed
254

Paul's avatar
Paul committed
255
256
        auto rank = a_shape.lens().size();

Paul's avatar
Paul committed
257
        std::array<char, 3> keys{'M', 'N', 'K'};
Paul's avatar
Format  
Paul committed
258
        std::array<std::size_t, 3> config{
Paul's avatar
Paul committed
259
            c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
Paul's avatar
Paul committed
260

Paul's avatar
Format  
Paul committed
261
262
        auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
        auto ip         = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
Paul's avatar
Format  
Paul committed
263
            return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
264
265
                   get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
Paul's avatar
Paul committed
266
        })};
Paul's avatar
Paul committed
267
        assert(inputs.size() < 4 or v.contains("post"));
Paul's avatar
Format  
Paul committed
268
        if(v.contains("post"))
Paul's avatar
Paul committed
269
        {
Paul's avatar
Paul committed
270
271
272
            ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
            ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
            ip.set_ds_op(v.at("post").to<std::string>());
Paul's avatar
Paul committed
273
274
        }

Paul's avatar
Paul committed
275
276
        auto padding = ip.get_pad(config);
        std::string gemm_type;
Paul's avatar
Format  
Paul committed
277
        for(auto i : range(padding.size()))
Paul's avatar
Paul committed
278
        {
Paul's avatar
Format  
Paul committed
279
            if(padding[i] != 0)
Paul's avatar
Paul committed
280
281
                gemm_type += keys[i];
        }
Paul's avatar
Format  
Paul committed
282
        if(gemm_type.empty())
Paul's avatar
Paul committed
283
284
285
286
287
            gemm_type = "Default";
        else
            gemm_type += "Padding";
        ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);

Paul's avatar
Paul committed
288
        auto blocks_per_batch = ip.get_grid_size(config);
Paul's avatar
Format  
Paul committed
289
290
291
292
        auto batch_count      = std::accumulate(c_shape.lens().rbegin() + 2,
                                           c_shape.lens().rend(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
Paul's avatar
Paul committed
293

Paul's avatar
Paul committed
294
        hip_compile_options options;
Paul's avatar
Paul committed
295
        auto block_size = ip.get_block_size();
Paul's avatar
Paul committed
296
        auto grid_size  = batch_count * blocks_per_batch;
Paul's avatar
Paul committed
297
        options.set_launch_params(v, grid_size * block_size, block_size);
Paul's avatar
Paul committed
298
        options.inputs         = inputs;
Paul's avatar
Paul committed
299
        options.output         = c_shape;
Paul's avatar
Paul committed
300
        options.kernel_name    = v.get("kernel", "ck_gemm_kernel");
Paul's avatar
Paul committed
301
302
        options.virtual_inputs = inputs;

Paul's avatar
Paul committed
303
        if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
Paul's avatar
Paul committed
304
305
            options.params += " -DMIGRAPHX_CK_CHECK=1";

Paul's avatar
Format  
Paul committed
306
        auto src = interpolate_string(ck_gemm_kernel,
Paul's avatar
Paul committed
307
                                      {{"instance", ip.str()},
Paul's avatar
Format  
Paul committed
308
309
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
Paul's avatar
Paul committed
310
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
Paul's avatar
Format  
Paul committed
311
312
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"kernel", options.kernel_name}});
Paul's avatar
Format  
Paul committed
313

Paul's avatar
Paul committed
314
315
316
317
318
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
Paul's avatar
Format  
Paul committed
319
320
        auto v      = op.to_value();
        v["kernel"] = "ck_gemm_kernel";
Paul's avatar
Paul committed
321
322
323
        if(not ins->module_inputs().empty())
        {
            auto* pm      = ins->module_inputs().front();
Paul's avatar
Format  
Paul committed
324
325
326
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm>";
Paul's avatar
Paul committed
327
            v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
Paul's avatar
Format  
Paul committed
328
        }
Paul's avatar
Paul committed
329

Paul's avatar
Paul committed
330
        auto shapes = to_shapes(ins->inputs());
Paul's avatar
Paul committed
331
        return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
Paul's avatar
Format  
Paul committed
332
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
Paul's avatar
Paul committed
333
334
335
336
            {
                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
                std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
            }
Paul's avatar
Paul committed
337
        });
Paul's avatar
Paul committed
338
339
340
341
342
343
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx