/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>

const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);

// Source template of the generated CK GEMM kernel. The ${preamble}, ${kernel},
// ${params}, ${args} and ${blocks_per_batch} placeholders are substituted by
// interpolate_string() in ck_gemm_compiler::compile_op before the text is handed
// to the HIP compiler, so nothing but valid kernel source may appear in here.
// NOTE(review): the GEMM alias below is a fixed instantiation; the template has no
// ${instance} placeholder even though compile_op supplies an "instance" entry.
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp"

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using Empty_Tuple   = ck::Tuple<>;

using GEMM = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<    
	Row,
	Row,
	Empty_Tuple,
	Row,
	int8_t,
	int8_t,
	int32_t,
	Empty_Tuple,
	int8_t, //EDataType
	PassThrough,
	PassThrough,
	PassThrough,
	ck::tensor_operation::device::GemmSpecialization::MNKPadding,
	256,
	128,
	128,
	16,
	4,
	4,
	4,
	1,
	S<8,2>,
	S<8,2>,
	S<8,1,1,4>,
	S<2,1,128,1>,
	S<1,2,0,3>,
	S<1,2,0,3>,
	S<4,1,1,4>,
	S<1,2,0,3>,
	S<1,1,1,4>,
	S<2,1,4,4>,
	S<8,1,32,1>,
	S<0,3,1,2>,
	S<0,3,1,2>,
	S<1,1,4,1>,
	S<0,3,1,2>,
	S<1,1,4,4>,
	S<0,1,2,3,4,5>,
	5,
	4>;

namespace migraphx {

${preamble}

extern "C" {

__global__ void ${kernel}(${params})
{
    transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
        ck_gemm<GEMM, ${blocks_per_batch}>(xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

Paul's avatar
Paul committed
135
136
// Ceiling division: returns the smallest q with q * y >= x. `y` must be non-zero.
// Computed as quotient plus a remainder test so that the intermediate value
// cannot wrap around (x + y - 1 overflows when x is close to the size_t maximum).
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    return x / y + (x % y != 0 ? 1 : 0);
}

Paul's avatar
Paul committed
137
struct instance
Paul's avatar
Paul committed
138
{
Paul's avatar
Paul committed
139
140
    std::vector<std::string> params;
    static const std::size_t block_size_index = 15;
Paul's avatar
Paul committed
141

Paul's avatar
Format  
Paul committed
142
    std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }
Paul's avatar
Paul committed
143

Paul's avatar
Format  
Paul committed
144
    std::size_t get_block_size() const { return int_at(block_size_index); }
Paul's avatar
Paul committed
145
146
147
148
149
150
151
152
153
154

    std::size_t get_pb(std::size_t i) const
    {
        assert(i < 4);
        return int_at(block_size_index + 1 + i);
    }

    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
    {
        std::array<std::size_t, 3> result{};
Paul's avatar
Format  
Paul committed
155
        for(auto i : range(config.size()))
Paul's avatar
Paul committed
156
157
158
159
160
161
162
163
        {
            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
        }
        return result;
    }

    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
    {
Paul's avatar
Paul committed
164
        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
Paul's avatar
Paul committed
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    }

    void set_ds_layout(const std::string& s)
    {
        assert(params[2] == "ck::Tuple<>");
        params[2] = s;
    }

    void set_ds_type(const std::string& s)
    {
        assert(params[8] == "ck::Tuple<>");
        params[8] = s;
    }

    void set_ds_op(const std::string& s)
    {
        assert(params[12] == "ck_passthrough");
        params[12] = s;
    }

    void set_gemm(const std::string& s)
    {
        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
        params[13] = s;
    }

Paul's avatar
Format  
Paul committed
191
    std::string str() const { return join_strings(params, ","); }
Paul's avatar
Paul committed
192
};
Paul's avatar
Paul committed
193

Paul's avatar
Paul committed
194
195
// Reports whether the last stride of `s` differs from 1; get_layout() in this file
// maps such shapes to the column-major CK layout.
// A shape without any strides is reported as not transposed rather than calling
// back() on an empty container, which would be undefined behaviour.
static bool transposed_matrix(const shape& s)
{
    const auto& strides = s.strides();
    return not strides.empty() and strides.back() != 1;
}

Paul's avatar
Format  
Paul committed
196
// Wraps `f` so that `action` runs (without arguments) immediately before every call.
// The returned callable forwards all of its arguments to `f` and discards f's result.
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [f, action](auto&&... args) {
        action();
        f(std::forward<decltype(args)>(args)...);
    };
}

Paul's avatar
Paul committed
205
206
207
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
Paul's avatar
Format  
Paul committed
208
    if(not fs::exists(s))
Paul's avatar
Paul committed
209
210
211
212
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}

Paul's avatar
Paul committed
213
214
static float matrix_distance(const shape& x, const shape& y)
{
Paul's avatar
Format  
Paul committed
215
    if(x.type() != y.type())
Paul's avatar
Paul committed
216
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
217
    if(transposed_matrix(x) != transposed_matrix(y))
Paul's avatar
Paul committed
218
        return std::numeric_limits<float>::max();
Paul's avatar
Format  
Paul committed
219
220
221
222
223
224
    auto sum_squared = std::inner_product(x.lens().rbegin(),
                                          x.lens().rbegin() + 2,
                                          y.lens().rbegin(),
                                          0,
                                          std::plus<>{},
                                          [](auto a, auto b) { return (a - b) * (a - b); });
Paul's avatar
Paul committed
225
226
227
    return std::sqrt(sum_squared);
}

Paul's avatar
Paul committed
228
229
230
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
Paul's avatar
Format  
Paul committed
231
    if(tuning.empty())
Paul's avatar
Paul committed
232
        std::cout << "*********** Warning: No CK tuning!" << std::endl;
Paul's avatar
Format  
Paul committed
233
    auto it = std::find_if(
Paul's avatar
Format  
Paul committed
234
        tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
Paul's avatar
Format  
Paul committed
235
236
    if(it == tuning.end())
    {
Paul's avatar
Paul committed
237
        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
Paul's avatar
Paul committed
238
239
        std::vector<std::pair<float, std::size_t>> w;
        std::transform(tuning.begin(), tuning.end(), std::back_inserter(w), [&](const auto& p) {
Paul's avatar
Format  
Paul committed
240
            if(inputs.size() < 3 or p.first.size() < 3)
Paul's avatar
Paul committed
241
                MIGRAPHX_THROW("Invalid CK config");
Paul's avatar
Format  
Paul committed
242
243
244
245
246
247
248
            auto avg_distance = std::inner_product(
                p.first.begin(),
                p.first.begin() + 3,
                inputs.begin(),
                0.0f,
                std::plus<>{},
                [](const auto& x, const auto& y) { return matrix_distance(x, y) / 3.0f; });
Paul's avatar
Paul committed
249
250
251
252
            return std::make_pair(avg_distance, p.second);
        });
        std::sort(w.begin(), w.end());
        std::size_t default_value = 4;
Paul's avatar
Format  
Paul committed
253
        if(not w.empty())
Paul's avatar
Paul committed
254
            default_value = w.front().second;
Paul's avatar
Paul committed
255
256
257
        auto tuning_val = value_of(MIGRAPHX_CK_TUNING_VALUE{}, default_value);
        std::cout << "*********** Warning: CK try tuning: " << tuning_val << std::endl;
        return tuning_val;
Paul's avatar
Paul committed
258
    }
Paul's avatar
Paul committed
259
260
261
    return it->second;
}

Paul's avatar
Paul committed
262
263
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
Paul's avatar
Paul committed
264
265
    static std::string get_layout(const shape& s)
    {
Paul's avatar
Paul committed
266
        return transposed_matrix(s) ? "ck::tensor_layout::gemm::ColumnMajor"
Paul's avatar
Format  
Paul committed
267
                                    : "ck::tensor_layout::gemm::RowMajor";
Paul's avatar
Paul committed
268
269
270
    }

    static std::string get_type(const shape& s)
Paul's avatar
Paul committed
271
    {
Paul's avatar
Format  
Paul committed
272
        if(s.type() == shape::half_type)
Paul's avatar
Paul committed
273
274
275
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }
Paul's avatar
Paul committed
276

Paul's avatar
Format  
Paul committed
277
    template <class Iterator, class F>
Paul's avatar
Paul committed
278
279
280
281
282
283
284
    static std::string ck_tuple(Iterator start, Iterator last, F f)
    {
        std::vector<std::string> s;
        std::transform(start, last, std::back_inserter(s), f);
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

Paul's avatar
Paul committed
285
286
    static std::vector<shape> adjust_inputs(std::vector<shape> inputs, bool& swap_inputs)
    {
Paul's avatar
Format  
Paul committed
287
        swap_inputs  = false;
Paul's avatar
Paul committed
288
        auto c_shape = inputs.back();
Paul's avatar
Format  
Paul committed
289
        if(not transposed_matrix(c_shape))
Paul's avatar
Paul committed
290
291
292
293
294
295
296
297
298
299
300
            return inputs;
        std::vector<int64_t> perm(c_shape.lens().size());
        std::iota(perm.begin(), perm.end(), 0);
        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
        std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](shape s) {
            return reorder_shape(s, perm);
        });
        swap_inputs = true;
        return inputs;
    }

301
302
    static std::size_t get_batch_count(const shape& s)
    {
Alan Turner's avatar
Alan Turner committed
303
304
        return std::accumulate(
            s.lens().rbegin() + 2, s.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
305
306
307
308
309
    }

    // Replaces a shape of rank > 2 by a two-dimensional shape of the same type in
    // which the batch count is multiplied into one of the two innermost dimensions:
    // the last one when the shape is transposed, the second-to-last one otherwise.
    // Shapes of rank <= 2 are left untouched.
    static void fold_batch_dims(shape& s)
    {
        const auto lens = s.lens();
        const auto rank = lens.size();
        if(rank <= 2)
            return;
        const auto batch_count = get_batch_count(s);
        const auto rows        = lens[rank - 2];
        const auto cols        = lens[rank - 1];
        s = transposed_matrix(s) ? shape{s.type(), {rows, cols * batch_count}}
                                 : shape{s.type(), {rows * batch_count, cols}};
    }

    // Replaces a shape of rank > 2 by a two-dimensional shape of the same type that
    // keeps only the two innermost dimensions. Shapes of rank <= 2 are left untouched.
    static void remove_batch_dims(shape& s)
    {
        const auto lens = s.lens();
        const auto rank = lens.size();
        if(rank > 2)
        {
            const auto rows = lens[rank - 2];
            const auto cols = lens[rank - 1];
            s               = shape{s.type(), {rows, cols}};
        }
    }

Paul's avatar
Paul committed
331
332
333
334
    // Operator names associated with this compiler (presumably the names the
    // compiler<> base registers it under; confirm against compiler.hpp).
    std::vector<std::string> names() const
    {
        std::vector<std::string> op_names{"ck_gemm", "gpu::ck_gemm"};
        return op_names;
    }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
Paul's avatar
Paul committed
335
336
        auto a_shape = inputs[0];
        auto b_shape = inputs[1];
Paul's avatar
Paul committed
337
        auto c_shape = inputs.back();
Paul's avatar
Paul committed
338

Alan Turner's avatar
Alan Turner committed
339
340
        auto rank           = a_shape.lens().size();
        auto b_strides      = b_shape.strides();
341
342
        bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;

Alan Turner's avatar
Alan Turner committed
343
344
345
346
347
        auto batch_count = get_batch_count(c_shape);
        auto m           = c_shape.lens()[rank - 2];
        m                = can_fold_batch ? m * batch_count : m;
        auto n           = c_shape.lens().back();
        auto k           = a_shape.lens().back();
Paul's avatar
Paul committed
348
        std::array<char, 3> keys{'M', 'N', 'K'};
349
        std::array<std::size_t, 3> config{m, n, k};
Paul's avatar
Format  
Paul committed
350
351
        auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
        auto ip         = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
Alan Turner's avatar
Alan Turner committed
352
            return true;/* get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
353
                   get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
Alan Turner's avatar
Alan Turner committed
354
                   get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; */
Paul's avatar
Paul committed
355
        })};
Paul's avatar
Paul committed
356
        assert(inputs.size() < 4 or v.contains("post"));
Paul's avatar
Format  
Paul committed
357
        if(v.contains("post"))
Paul's avatar
Paul committed
358
        {
Paul's avatar
Paul committed
359
360
361
            ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
            ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
            ip.set_ds_op(v.at("post").to<std::string>());
Paul's avatar
Paul committed
362
363
        }

Paul's avatar
Paul committed
364
365
        auto padding = ip.get_pad(config);
        std::string gemm_type;
Paul's avatar
Format  
Paul committed
366
        for(auto i : range(padding.size()))
Paul's avatar
Paul committed
367
        {
Paul's avatar
Format  
Paul committed
368
            if(padding[i] != 0)
Paul's avatar
Paul committed
369
370
                gemm_type += keys[i];
        }
Paul's avatar
Format  
Paul committed
371
        if(gemm_type.empty())
Paul's avatar
Paul committed
372
373
374
375
376
            gemm_type = "Default";
        else
            gemm_type += "Padding";
        ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);

Alan Turner's avatar
Alan Turner committed
377
        auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128);;//ip.get_grid_size(config);
Paul's avatar
Paul committed
378

Paul's avatar
Paul committed
379
        hip_compile_options options;
Alan Turner's avatar
Alan Turner committed
380
        auto block_size = 256;//ip.get_block_size();
381
        auto grid_size  = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
Paul's avatar
Paul committed
382
        options.set_launch_params(v, grid_size * block_size, block_size);
Alan Turner's avatar
Alan Turner committed
383
384
385
386
387
        //auto new_inputs = inputs;
        auto new_inputs = inputs;
        // auto out_s = inputs.back();
        // new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
        options.inputs         = new_inputs;
Paul's avatar
Paul committed
388
        options.output         = c_shape;
Paul's avatar
Paul committed
389
        options.kernel_name    = v.get("kernel", "ck_gemm_kernel");
Alan Turner's avatar
Alan Turner committed
390
        options.virtual_inputs = new_inputs;
Alan Turner's avatar
Alan Turner committed
391
        if(can_fold_batch)
392
        {
Alan Turner's avatar
Alan Turner committed
393
            auto vinputs = new_inputs;
394
395
396
397
398
            fold_batch_dims(vinputs[0]);
            remove_batch_dims(vinputs[1]);
            std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
            options.virtual_inputs = vinputs;
        }
Paul's avatar
Paul committed
399

Paul's avatar
Paul committed
400
        if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
Paul's avatar
Paul committed
401
402
            options.params += " -DMIGRAPHX_CK_CHECK=1";

Paul's avatar
Format  
Paul committed
403
        auto src = interpolate_string(ck_gemm_kernel,
Paul's avatar
Paul committed
404
                                      {{"instance", ip.str()},
Paul's avatar
Format  
Paul committed
405
406
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
Paul's avatar
Paul committed
407
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
Paul's avatar
Format  
Paul committed
408
409
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"kernel", options.kernel_name}});
Paul's avatar
Format  
Paul committed
410

Paul's avatar
Paul committed
411
412
413
414
415
        return compile_hip_code_object(src, options);
    }

    // Compiles the instruction `ins` carrying operation `op`.
    // When the instruction has module inputs, the first module is turned into a
    // pointwise post-operation: its generated code becomes the kernel preamble and
    // the kernel is named after the module's operators. The returned replacement
    // additionally prints the A, B and C shapes as JSON when MIGRAPHX_LOG_CK_GEMM
    // is enabled.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v      = op.to_value();
        v["kernel"] = "ck_gemm_kernel";

        const auto& post_modules = ins->module_inputs();
        if(not post_modules.empty())
        {
            auto* pm          = post_modules.front();
            auto post_source  = generate_pointwise(*pm, "post_ck_gemm_function");
            v["preamble"]     = post_source +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm>";
            v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
        }

        auto shapes = to_shapes(ins->inputs());

        // Runs right before the replacement is applied.
        auto log_gemm = [shapes] {
            if(not enabled(MIGRAPHX_LOG_CK_GEMM{}))
                return;
            std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
            std::cout << "ck_gemm: " << to_json_string(to_value(gemm_shapes)) << std::endl;
        };
        return action_decorate(replace(compile_op(ctx, shapes, v)), log_gemm);
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx