ck_gemm.cpp 14.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
Paul's avatar
Paul committed
32
#include <migraphx/gpu/compile_gen.hpp>
Paul's avatar
Paul committed
33
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
34
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
35
36
37
38
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
39
#include <migraphx/file_buffer.hpp>
Paul's avatar
Paul committed
40

41
42
43
44
45
46
47
48
49
50
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/tensor_operation_instance/solution_instances/gemm_multiple_d_xdlop_cshuffle.hpp"

#include <iostream>

Paul's avatar
Paul committed
51
52
// Forward declaration only: neither a definition nor a call of get_instance
// appears in the visible part of this file.
// NOTE(review): presumably returns the i-th instance parameter list accepted
// by `pred` -- confirm against the definition before relying on this.
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
Paul's avatar
Paul committed
53

Paul's avatar
Paul committed
54
55
56
57
58
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace gpu {

Paul's avatar
Paul committed
59
60
using namespace migraphx::gpu::gen; // NOLINT

Paul's avatar
Paul committed
61
// When enabled, the action attached in ck_gemm_compiler::compile prints the
// A, B and last input shapes as JSON, prefixed with "ck_gemm: ".
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
Paul's avatar
Paul committed
62
// Path of the tuning file handed to read_tuning() by get_tuning_for();
// an empty string is used when the variable is not set.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
Paul's avatar
Paul committed
63
// Tuning value that get_tuning_for() uses, in place of its computed default,
// when the tuning data has no exact entry for the requested shapes.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
Paul's avatar
Paul committed
64
// When enabled, compile_op appends " -DMIGRAPHX_CK_CHECK=1" to the kernel
// compile options (the same effect as the "check" entry of the op value).
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
Paul's avatar
Paul committed
65

Paul's avatar
Paul committed
66
67
68
69
// Source template of the generated HIP kernel. The ${...} placeholders
// (preamble, kernel, params, args, solution, blocks_per_batch) are filled in
// by interpolate_string() inside ck_gemm_compiler::compile_op, so the literal
// must contain nothing but valid kernel source and those placeholders.
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp>

namespace migraphx {

${preamble}

extern "C" {

__global__ void ${kernel}(${params})
{
    transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
        ck_gemm<${solution}, ${blocks_per_batch}>(xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

Paul's avatar
Paul committed
92
93
// Divides x by y and rounds the result up to the next whole number.
// Computed as quotient plus a remainder check rather than (x + y - 1) / y,
// so the intermediate value cannot wrap around when x is close to the
// largest std::size_t. y must be non-zero.
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    return x / y + static_cast<std::size_t>(x % y != 0);
}

Paul's avatar
Paul committed
94
struct instance
Paul's avatar
Paul committed
95
{
Paul's avatar
Paul committed
96
97
    std::vector<std::string> params;
    static const std::size_t block_size_index = 15;
Paul's avatar
Paul committed
98

Paul's avatar
Format  
Paul committed
99
    std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
Paul's avatar
Paul committed
100

Paul's avatar
Format  
Paul committed
101
    std::size_t get_block_size() const { return int_at(block_size_index); }
Paul's avatar
Paul committed
102
103
104
105
106
107
108
109
110
111

    std::size_t get_pb(std::size_t i) const
    {
        assert(i < 4);
        return int_at(block_size_index + 1 + i);
    }

    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
    {
        std::array<std::size_t, 3> result{};
Paul's avatar
Format  
Paul committed
112
        for(auto i : range(config.size()))
Paul's avatar
Paul committed
113
114
115
116
117
118
119
120
        {
            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
        }
        return result;
    }

    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
    {
Paul's avatar
Paul committed
121
        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
Paul's avatar
Paul committed
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
    }

    void set_ds_layout(const std::string& s)
    {
        assert(params[2] == "ck::Tuple<>");
        params[2] = s;
    }

    void set_ds_type(const std::string& s)
    {
        assert(params[8] == "ck::Tuple<>");
        params[8] = s;
    }

    void set_ds_op(const std::string& s)
    {
        assert(params[12] == "ck_passthrough");
        params[12] = s;
    }

    void set_gemm(const std::string& s)
    {
        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
        params[13] = s;
    }

Paul's avatar
Format  
Paul committed
148
    std::string str() const { return join_strings(params, ","); }
Paul's avatar
Paul committed
149
};
Paul's avatar
Paul committed
150

Paul's avatar
Paul committed
151
152
// True when the stride of the last dimension of `s` is anything other than 1.
static bool transposed_matrix(const shape& s)
{
    const auto innermost_stride = s.strides().back();
    return not(innermost_stride == 1);
}

Paul's avatar
Format  
Paul committed
153
// Wraps `f` so that every call first runs `action()` and then invokes `f`
// with the forwarded arguments. Whatever `f` returns (including nothing) is
// handed back to the caller instead of being discarded. Both callables are
// captured by copy in the returned lambda.
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [=](auto&&... xs) -> decltype(auto) {
        action();
        return f(std::forward<decltype(xs)>(xs)...);
    };
}

Paul's avatar
Paul committed
162
163
164
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
Paul's avatar
Format  
Paul committed
165
    if(not fs::exists(s))
Paul's avatar
Paul committed
166
167
168
169
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}

Paul's avatar
Paul committed
170
171
static float matrix_distance(const shape& x, const shape& y)
{
Paul's avatar
Format  
Paul committed
172
    if(x.type() != y.type())
Paul's avatar
Paul committed
173
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
174
    if(transposed_matrix(x) != transposed_matrix(y))
Paul's avatar
Paul committed
175
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
176
177
178
179
180
181
    auto sum_squared = std::inner_product(x.lens().rbegin(),
                                          x.lens().rbegin() + 2,
                                          y.lens().rbegin(),
                                          0,
                                          std::plus<>{},
                                          [](auto a, auto b) { return (a - b) * (a - b); });
Paul's avatar
Paul committed
182
183
184
    return std::sqrt(sum_squared);
}

Paul's avatar
Paul committed
185
186
187
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
Paul's avatar
Format  
Paul committed
188
    if(tuning.empty())
Paul's avatar
Paul committed
189
        std::cout << "*********** Warning: No CK tuning!" << std::endl;
Paul's avatar
Format  
Paul committed
190
    auto it = std::find_if(
Paul's avatar
Format  
Paul committed
191
        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
Paul's avatar
Format  
Paul committed
192
193
    if(it == tuning.end())
    {
Paul's avatar
Paul committed
194
        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
Paul's avatar
Paul committed
195
196
        std::vector<std::pair<float, std::size_t>> w;
        std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
Paul's avatar
Format  
Paul committed
197
            if(inputs.size() < 3 or p.first.size() < 3)
Paul's avatar
Paul committed
198
                MIGRAPHX_THROW("Invalid CK config");
Paul's avatar
Format  
Paul committed
199
200
201
202
203
204
205
            auto avg_distance = std::inner_product(
                p.first.begin(),
                p.first.begin() + 3,
                inputs.begin(),
                0.0f,
                std::plus<>{},
                [](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
Paul's avatar
Paul committed
206
207
208
209
            return std::make_pair(avg_distance, p.second);
        });
        std::sort(w.begin(), w.end());
        std::size_t default_value = 4;
Paul's avatar
Format  
Paul committed
210
        if(not w.empty())
Paul's avatar
Paul committed
211
            default_value = w.front().second;
Paul's avatar
Paul committed
212
213
214
        auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
        std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
        return tuning_val;
Paul's avatar
Paul committed
215
    }
Paul's avatar
Paul committed
216
217
218
    return it->second;
}

Paul's avatar
Paul committed
219
220
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
Paul's avatar
Paul committed
221
222
    static std::string get_layout(const shape& s)
    {
Paul's avatar
Paul committed
223
        return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
Paul's avatar
Format  
Paul committed
224
                                    : "ck::tensor_layout::gemm::RowMajor";
Paul's avatar
Paul committed
225
226
227
    }

    static std::string get_type(const shape& s)
Paul's avatar
Paul committed
228
    {
Paul's avatar
Format  
Paul committed
229
        if(s.type() == shape::half_type)
Paul's avatar
Paul committed
230
231
232
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }
Paul's avatar
Paul committed
233

Paul's avatar
Format  
Paul committed
234
    template <class Iterator, class F>
Paul's avatar
Paul committed
235
236
237
238
239
240
241
    static std::string ck_tuple(Iterator start, Iterator last, F f)
    {
        std::vector<std::string> s;
        std::transform(start, last, std::back_inserter(s), f);
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

Paul's avatar
Paul committed
242
243
    static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
    {
Paul's avatar
Format  
Paul committed
244
        swap_inputs  = false;
Paul's avatar
Paul committed
245
        auto c_shape = inputs.back();
Paul's avatar
Format  
Paul committed
246
        if(not transposed_matrix(c_shape))
Paul's avatar
Paul committed
247
248
249
250
251
252
253
254
255
256
257
            return inputs;
        std::vector<int64_t> perm(c_shape.lens().size());
        std::iota(perm.begin(), perm.end(), 0);
        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
            return reorder_shape(s, perm);
        });
        swap_inputs = true;
        return inputs;
    }

258
259
    static std::size_t get_batch_count(const shape& s)
    {
Alan Turner's avatar
Alan Turner committed
260
261
        return std::accumulate(
            s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
262
263
264
265
266
    }

    static void fold_batch_dims(shape& s)
    {
        auto lens = s.lens();
Alan Turner's avatar
Alan Turner committed
267
        if(lens.size() <= 2)
268
269
            return;
        auto batch_count = get_batch_count(s);
Alan Turner's avatar
Alan Turner committed
270
271
272
        auto m1          = lens.at(lens.size() - 2);
        auto m2          = lens.at(lens.size() - 1);
        if(transposed_matrix(s))
273
274
275
276
277
278
279
280
            s = shape{s.type(), {m1, m2 * batch_count}};
        else
            s = shape{s.type(), {m1 * batch_count, m2}};
    }

    static void remove_batch_dims(shape& s)
    {
        auto lens = s.lens();
Alan Turner's avatar
Alan Turner committed
281
        if(lens.size() <= 2)
282
283
284
            return;
        auto m1 = lens.at(lens.size() - 2);
        auto m2 = lens.at(lens.size() - 1);
Alan Turner's avatar
Alan Turner committed
285
        s       = shape{s.type(), {m1, m2}};
286
287
    }

Paul's avatar
Paul committed
288
289
290
291
    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
Alan Turner's avatar
Alan Turner committed
292
293
294
        auto a_shape      = inputs[0];
        auto b_shape      = inputs[1];
        auto c_shape      = inputs.back();
295
        auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
Paul's avatar
Paul committed
296

Alan Turner's avatar
Alan Turner committed
297
298
        auto rank           = a_shape.lens().size();
        auto b_strides      = b_shape.strides();
299
300
        bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;

Alan Turner's avatar
Alan Turner committed
301
302
303
304
305
        auto batch_count = get_batch_count(c_shape);
        auto m           = c_shape.lens()[rank - 2];
        m                = can_fold_batch ? m * batch_count : m;
        auto n           = c_shape.lens().back();
        auto k           = a_shape.lens().back();
306
307

        const auto numDTensors = inputs.size() - 3;
Alan Turner's avatar
Alan Turner committed
308
309
310
311
312
313
314
        const bool transA      = transposed_matrix(a_shape);
        const bool transB      = transposed_matrix(b_shape);
        const bool transCDE    = transposed_matrix(c_shape);
        const auto a_type      = get_type(a_shape);
        const auto b_type      = get_type(b_shape);
        const auto cde_type =
            ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type); // get_type(c_shape);
315
316
        const auto cde_layout = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);

Alan Turner's avatar
Alan Turner committed
317
318
        std::string ck_passthrough =
            "ck_passthrough"; //"ck::tensor_operation::element_wise::PassThrough";
319
        std::string cde_op = ck_passthrough;
Paul's avatar
Paul committed
320
        assert(inputs.size() < 4 or v.contains("post"));
Paul's avatar
Format  
Paul committed
321
        if(v.contains("post"))
Paul's avatar
Paul committed
322
        {
323
            cde_op = v.at("post").to<std::string>();
Paul's avatar
Paul committed
324
325
        }

Alan Turner's avatar
Alan Turner committed
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
        auto problem =
            ck::tensor_operation::device::instance::Problem{static_cast<ck::index_t>(m),
                                                            static_cast<ck::index_t>(n),
                                                            static_cast<ck::index_t>(k),
                                                            static_cast<ck::index_t>(numDTensors),
                                                            transA,
                                                            transB,
                                                            transCDE,
                                                            a_type,
                                                            b_type,
                                                            cde_type,
                                                            ck_passthrough,
                                                            ck_passthrough,
                                                            cde_op,
                                                            cde_layout};
Alan Turner's avatar
Alan Turner committed
341
342
343
344
345
        const auto solutions         = problem.GetSolutions();
        const auto solution = solutions.at(tuning_value);
        const auto template_str  = solution.GetStr();
        const auto blocks_per_batch = solution.GetGridSize();
        const auto block_size       = solution.GetBlockSize();
Paul's avatar
Paul committed
346

Paul's avatar
Paul committed
347
        hip_compile_options options;
Alan Turner's avatar
Alan Turner committed
348
        auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
Paul's avatar
Paul committed
349
        options.set_launch_params(v, grid_size * block_size, block_size);
Paul's avatar
Paul committed
350
        options.inputs         = inputs;
Paul's avatar
Paul committed
351
        options.output         = c_shape;
Paul's avatar
Paul committed
352
        options.kernel_name    = v.get("kernel", "ck_gemm_kernel");
Paul's avatar
Paul committed
353
        options.virtual_inputs = inputs;
Alan Turner's avatar
Alan Turner committed
354
        if(can_fold_batch)
355
356
357
358
359
360
361
        {
            auto vinputs = inputs;
            fold_batch_dims(vinputs[0]);
            remove_batch_dims(vinputs[1]);
            std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
            options.virtual_inputs = vinputs;
        }
Paul's avatar
Paul committed
362

Paul's avatar
Paul committed
363
        if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
Paul's avatar
Paul committed
364
            options.params += " -DMIGRAPHX_CK_CHECK=1";
Alan Turner's avatar
Alan Turner committed
365

Paul's avatar
Format  
Paul committed
366
        auto src = interpolate_string(ck_gemm_kernel,
Alan Turner's avatar
Alan Turner committed
367
                                      {{"solution", template_str},
Paul's avatar
Format  
Paul committed
368
369
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
Paul's avatar
Paul committed
370
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
Paul's avatar
Format  
Paul committed
371
372
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"kernel", options.kernel_name}});
373
        std::cout << src << std::endl;
Paul's avatar
Paul committed
374
375
376
377
378
        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
Paul's avatar
Format  
Paul committed
379
380
        auto v      = op.to_value();
        v["kernel"] = "ck_gemm_kernel";
Paul's avatar
Paul committed
381
382
383
        if(not ins->module_inputs().empty())
        {
            auto* pm      = ins->module_inputs().front();
Paul's avatar
Format  
Paul committed
384
385
386
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm>";
Paul's avatar
Paul committed
387
            v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
Paul's avatar
Format  
Paul committed
388
        }
Paul's avatar
Paul committed
389

Paul's avatar
Paul committed
390
        auto shapes = to_shapes(ins->inputs());
Paul's avatar
Paul committed
391
        return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
Paul's avatar
Format  
Paul committed
392
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
Paul's avatar
Paul committed
393
394
395
396
            {
                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
                std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
            }
Paul's avatar
Paul committed
397
        });
Paul's avatar
Paul committed
398
399
400
401
402
403
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx