ck_gemm.cpp 6.72 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>

#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
Paul's avatar
Paul committed
33
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
34
35
36
37
38
39
40
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/env.hpp>
Paul's avatar
Paul committed
41
#include <migraphx/file_buffer.hpp>
Paul's avatar
Paul committed
42

Paul's avatar
Paul committed
43
44
// Forward declaration of the CK GEMM instance lookup; the definition is not in this file.
// As used by ck_gemm_compiler::compile_op below:
//   - `i` is the tuning value (from the "tuning_val" attribute or get_tuning_for()).
//   - `pred` receives a candidate instance's parameter strings and returns true when the
//     candidate matches the layouts (entries 0-2) and element types (entries 3-5) of A, B, C.
//   - The returned strings are joined with "," to form the CKDeviceGemm template arguments,
//     and entries 13, 14 and 15 are parsed as unsigned integers for launch sizing.
// NOTE(review): presumably `i` indexes into the instances accepted by `pred` — confirm
// against the definition.
const std::vector<std::string>&
get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
Paul's avatar
Paul committed
45

Paul's avatar
Paul committed
46
47
48
49
50
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace gpu {

Paul's avatar
Paul committed
51
// When enabled, the replacement produced by ck_gemm_compiler::compile prints
// "ck_gemm: " followed by the instruction's input shapes as JSON to stdout each time it runs.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
Paul's avatar
Paul committed
52
// Path of a JSON tuning file. get_tuning_for() reads it once (defaulting to the empty
// string); a path that does not exist yields an empty tuning table.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
Paul's avatar
Paul committed
53

Paul's avatar
Paul committed
54
55
56
57
58
59
60
61
62
// HIP source template for the CK GEMM kernel that is compiled at runtime.
// `${instance}` is the only placeholder in the text; compile_op() substitutes it (via
// interpolate_string) with the comma-joined parameter list of the selected CK instance,
// which becomes the template argument list of CKDeviceGemm. The text between the raw-string
// delimiters is passed verbatim to the device compiler, so it must contain nothing but
// valid kernel source.
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>

#include <hip/hip_runtime_api.h>

namespace migraphx {

using gemm_t = CKDeviceGemm<${instance}>;

extern "C" {

__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
{
    make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) {
        ck_gemm<gemm_t>(a, b, c);
    });
}

}

} // namespace migraphx

)__migraphx__";

Paul's avatar
Paul committed
80
81
82
// Ceiling of x / y for unsigned operands. `y` must be non-zero.
// Computed from quotient and remainder instead of (x + y - 1) / y, because the latter
// wraps around when x is close to the maximum value of std::size_t.
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    return x / y + (x % y == 0 ? 0 : 1);
}

// Position of the block-size entry in a CK instance's parameter list. get_grid_size()
// reads the two entries that follow it as the per-block divisors for m and n.
// Declared constexpr: it is a fixed layout constant and is never meant to be modified.
static constexpr std::size_t block_size_index = 13;
Paul's avatar
Paul committed
83

Paul's avatar
Paul committed
84
static std::size_t get_block_size(const std::vector<std::string>& s)
Paul's avatar
Paul committed
85
{
Paul's avatar
Paul committed
86
    return std::stoull(s[block_size_index]);
Paul's avatar
Paul committed
87
88
}

Paul's avatar
Paul committed
89
static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
Paul's avatar
Paul committed
90
{
Paul's avatar
Format  
Paul committed
91
92
    auto mpb = std::stoull(s[block_size_index + 1]);
    auto npb = std::stoull(s[block_size_index + 2]);
Paul's avatar
Paul committed
93
94
    return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
}
Paul's avatar
Paul committed
95

Paul's avatar
Format  
Paul committed
96
// Wraps a callable so that a side action runs immediately before every invocation.
// @param f       callable to wrap; arguments are perfectly forwarded to it.
// @param action  nullary callable invoked first on each call.
// @return        a generic lambda owning `f` and `action`. It returns whatever `f` returns
//                (including void), so the wrapper can stand in for `f` wherever the result
//                is used as well as where it is ignored.
// Both callables are moved into the closure to avoid an extra copy of captured state.
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [f = std::move(f), action = std::move(action)](auto&&... xs) -> decltype(auto) {
        action();
        return f(std::forward<decltype(xs)>(xs)...);
    };
}

Paul's avatar
Paul committed
105
106
107
using tuning_entry = std::pair<std::vector<shape>, size_t>;
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
Paul's avatar
Format  
Paul committed
108
    if(not fs::exists(s))
Paul's avatar
Paul committed
109
110
111
112
113
114
115
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}

// Returns the tuning value recorded for exactly this list of input shapes.
// The table is loaded once, on first call, from the file named by MIGRAPHX_CK_TUNING
// (empty string when unset). A warning is printed to stdout on every call while the table
// is empty, and another when no entry matches `inputs`; in the latter case the fallback
// value 4 is returned. compile_op() passes the result to get_instance() as its index.
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    static const auto table = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));

    if(table.empty())
        std::cout << "*********** Warning: No CK tuning!" << std::endl;

    // First entry whose shape list equals `inputs` wins.
    for(const auto& entry : table)
    {
        if(entry.first == inputs)
            return entry.second;
    }

    std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
    return 4;
}

Paul's avatar
Paul committed
128
129
// JIT compiler for the "ck_gemm" / "gpu::ck_gemm" operations. It selects a CK GEMM instance
// matching the layouts and element types of the inputs, instantiates the ck_gemm_kernel
// source template with it and compiles the result into a HIP code object.
struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
    // CK layout tag for a shape: ColumnMajor when the shape reports itself as transposed,
    // RowMajor otherwise.
    static std::string get_layout(const shape& s)
    {
        if(s.transposed())
            return "ck::tensor_layout::gemm::ColumnMajor";
        return "ck::tensor_layout::gemm::RowMajor";
    }

    // Element type name used in the instance description: half maps to ck::half_t, every
    // other type to the name returned by shape::cpp_type.
    static std::string get_type(const shape& s)
    {
        if(s.type() == shape::half_type)
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }

    // Operation names this compiler is registered for.
    std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }

    // Builds the code-object operation for one GEMM.
    // @param inputs  shapes of A, B and C in that order; at least three are required.
    // @param v       operation attributes; "tuning_val", if present, overrides the value
    //                looked up by get_tuning_for().
    // @throws std::out_of_range if fewer than three shapes are supplied (checked access
    //         instead of unchecked operator[]).
    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
        // Bind by reference: the shapes are only read here, no copies are needed.
        const auto& a_shape = inputs.at(0);
        const auto& b_shape = inputs.at(1);
        const auto& c_shape = inputs.at(2);

        const auto m  = c_shape.lens().front();
        const auto n  = c_shape.lens().back();
        const auto k  = a_shape.lens().back();
        const auto sa = a_shape.strides().front();
        const auto sb = b_shape.strides().front();
        const auto sc = c_shape.strides().front();

        // NOTE(review): get_tuning_for() is evaluated even when "tuning_val" is present,
        // so its warnings may be printed although the looked-up value is not used.
        auto i = v.get("tuning_val", get_tuning_for(inputs));

        // Accept only instances whose first six parameters match the A/B/C layouts and
        // A/B/C element types.
        const auto& instance = get_instance(i, [&](const auto& x) -> bool {
            return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
                   get_layout(c_shape) == x[2] and get_type(a_shape) == x[3] and
                   get_type(b_shape) == x[4] and get_type(c_shape) == x[5];
        });

        const auto block_size = get_block_size(instance);
        const auto grid_size  = get_grid_size(instance, m, n);

        hip_compile_options options;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs         = inputs;
        options.output         = c_shape;
        options.kernel_name    = "ck_gemm_kernel";
        options.virtual_inputs = inputs;

        // Only ${instance} occurs in the kernel template; the remaining keys are supplied
        // as well but currently have no matching placeholder.
        auto src = interpolate_string(ck_gemm_kernel,
                                      {{"instance", join_strings(instance, ",")},
                                       {"m", to_string(m)},
                                       {"k", to_string(k)},
                                       {"n", to_string(n)},
                                       {"sa", to_string(sa)},
                                       {"sb", to_string(sb)},
                                       {"sc", to_string(sc)}});

        return compile_hip_code_object(src, options);
    }

    // Compiles the kernel for `ins` and returns the replacement action. The action is
    // decorated so that, when MIGRAPHX_LOG_CK_GEMM is enabled, the input shapes are printed
    // as JSON to stdout right before the replacement runs.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto shapes = to_shapes(ins->inputs());
        // Capture the shapes by value: the returned callable outlives this function.
        return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [shapes] {
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
                std::cout << "ck_gemm: " << to_json_string(to_value(shapes)) << std::endl;
        });
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx