/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR);

bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
    const bool mlir_disabled = enabled(MIGRAPHX_DISABLE_MLIR{});
    return not mlir_disabled;
#else
    return false;
#endif
}

#ifdef MIGRAPHX_MLIR

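// Operation that wraps a fused subgraph to be compiled by rocMLIR. The fused
// instructions live in the single attached submodule; compute_shape walks that
// submodule to propagate the input shapes through to its @return instruction.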
struct mlir_op
{
    std::string name() const { return "gpu::mlir_op"; }
    operation op = make_op("convolution");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.packed_or_broadcasted();
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");

        module_ref mod = mods[0];
        auto type      = mod->get_output_shapes().front().type();
        std::unordered_map<instruction_ref, shape> ins_shapes;
        size_t param_cnt               = 0;
        std::vector<std::string> names = mod->get_parameter_names();
        std::sort(names.begin(), names.end());
        for(const std::string& param_name : names)
        {
            ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
        }
        for(auto ins : iterator_for(*mod))
        {
            if(ins->name() == "@param")
            {
                continue;
            }
            if(ins->name() == "@literal")
            {
                ins_shapes[ins] = ins->get_shape();
                continue;
            }
            if(ins->name() == "@return")
            {
                auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
                if(not s.standard())
                    MIGRAPHX_THROW("MLIR doesnt support non-standard output");
                return s;
            }
            std::vector<shape> input_shapes;
            input_shapes.resize(ins->inputs().size());
            std::transform(ins->inputs().begin(),
                           ins->inputs().end(),
                           input_shapes.begin(),
                           [&](auto in) { return ins_shapes[in]; });
            ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
        }
        MIGRAPHX_THROW("No return found in the submodule");
    }
};
MIGRAPHX_REGISTER_OP(mlir_op);

namespace {
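// Copies the gemm/convolution instruction into the submodule mm together with
// any reshape-like operators (slice, transpose, contiguous, reshape, squeeze,
// flatten, unsqueeze) feeding its inputs. Returns the new gemm/conv
// instruction inside the submodule along with the original top-level inputs
// that become the inputs of the fused mlir_op.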
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
    std::vector<instruction_ref> top_inputs;
    std::vector<instruction_ref> imm_inputs;
    size_t input_cnt = 0;
    for(instruction_ref input : gemm_based_op->inputs())
    {
        std::vector<operation> op_stream;
        while(contains(
            {"slice", "transpose", "contiguous", "reshape", "squeeze", "flatten", "unsqueeze"},
            input->name()))
        {
            operation op = input->get_operator();
            if(contains({"squeeze", "flatten", "unsqueeze"}, input->name()))
            {
                op = migraphx::make_op("reshape", {{"dims", input->get_shape().lens()}});
            }
            op_stream.push_back(op);
            input = input->inputs().at(0);
        }
        top_inputs.push_back(input);
        instruction_ref prev_input =
            mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
        for(const auto& op : reverse(op_stream))
        {
            prev_input = mm->add_instruction(op, {prev_input});
        }
        imm_inputs.push_back(prev_input);
    }
    instruction_ref new_gemm_based_op =
        mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
    return {new_gemm_based_op, top_inputs};
}

enum class mlir_mode
{
    all,
    fast,
    int8,
    none
};

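// Matches dot/quant_dot instructions to offload to MLIR. In mlir_mode::fast a
// size heuristic rejects GEMMs with k > 1024 or sqrt(groups) * m * n / k below
// 2048, leaving them to the default GEMM lowering.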
auto is_mlir_dot(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        if(ins->name() != "dot" and ins->name() != "quant_dot")
            return false;
        if(mode != mlir_mode::fast)
            return true;
        auto a  = ins->inputs().front()->get_shape();
        auto b  = ins->inputs().back()->get_shape();
        float m = a.lens()[a.lens().size() - 2];
        float n = b.lens().back();
        float k = a.lens().back();
        float g = a.elements() / (m * k);
        if(k > 1024)
            return false;
        auto ratio = std::sqrt(g) * m * n / k;
        if(ratio < 2048)
            return false;
        return true;
    });
}

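// Matches convolution/quant_convolution instructions eligible for MLIR: the
// group count must be 1 and the output must be 4-D. int8 convolutions are
// always accepted; in fast mode, square filters whose size is a multiple of
// three (e.g. 3x3) are skipped.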
auto is_mlir_conv(mlir_mode mode)
{
    return match::make_basic_pred_matcher([=](instruction_ref ins) {
        if(mode == mlir_mode::none)
            return false;
        if(ins->name() != "convolution" and ins->name() != "quant_convolution")
            return false;
        value v    = ins->get_operator().to_value();
        auto group = v.at("group").to<int>();
        if(group != 1)
            return false;
        // Avoid MLIR assertion: Index < Length && "Invalid index!"
        if(ins->get_shape().lens().size() != 4)
            return false;
        if(ins->get_shape().type() == shape::int8_type)
            return true;
        if(mode == mlir_mode::int8)
            return false;
        if(mode == mlir_mode::all)
            return true;
        auto w = ins->inputs().at(1)->get_shape();
        if(w.lens().size() != 4)
            return true;
        if(w.lens()[2] != w.lens()[3])
            return true;
        return (w.lens()[3] % 3) != 0;
    });
}

struct find_mlir_fused_ops
{
    mlir_mode conv_mode = mlir_mode::none;
    mlir_mode dot_mode  = mlir_mode::none;
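    // Matches a pointwise instruction that has a dot or convolution (possibly
    // through a contiguous) among its inputs.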
    auto matcher() const
    {
        auto dot_or_conv = match::skip(match::name("contiguous"))(
            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
    }

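    // Copies every literal of the pointwise module into the new submodule,
    // broadcast to the output shape of the fused op, and returns a map from
    // the original literal instructions to their broadcast copies.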
    std::unordered_map<instruction_ref, instruction_ref>
    create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
    {
        std::unordered_map<instruction_ref, instruction_ref> ins_map;
        for(auto ins : iterator_for(*pm))
        {
            if(ins->name() != "@literal")
            {
                continue;
            }
            literal r               = ins->get_literal();
            instruction_ref literal = mm->add_literal(r);
            instruction_ref mbcast  = mm->add_instruction(
                make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
            ins_map[ins] = mbcast;
        }
        return ins_map;
    }

    // Whitelist supported fusion options, including imposing type constraints
    // for cases where MLIR only supports an operation (usually a pointwise function)
    // on particular types.
    bool is_pointwise_op_supported_by_mlir(const instruction& i) const
    {
        using type_t                                      = shape::type_t;
        const auto& name                                  = i.name();
        const auto result_type                            = i.get_shape().type();
        const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                             type_t::half_type,
                                                             type_t::int8_type,
                                                             type_t::int32_type,
                                                             type_t::bool_type};
        // Preliminary type check.
        if(not contains(allowed_types, result_type))
        {
            return false;
        }
        const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
        const std::initializer_list<std::string> no_bool_ops  = {
            "convolution",
            "quant_convolution",
            "dot",
            "quant_dot",
            "add",
            "clip",
            "relu",
            "sub",
            "mul",
            "div",
            "pow",
            "where",
            "quantizelinear",
            "dequantizelinear",
            "abs",
            "neg",
        };
        const std::initializer_list<std::string> fp_only_ops = {
            "ceil",
            "erf",
            "exp",
            "floor",
            "log",
            "recip",
            "rsqrt",
            "sigmoid",
            "softmax",
            "tanh",
        };
        bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
        if(contains(any_type_ops, name))
            return true;
        if(result_type != type_t::bool_type and contains(no_bool_ops, name))
            return true;
        if(is_float and contains(fp_only_ops, name))
            return true;
        // Only conversions between floating-point types are known to be
        // unambiguously supported.
        if(is_float and name == "convert")
        {
            return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
                return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
            });
        }
        return false;
    }

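    // Builds a submodule containing the gemm/convolution, its reshape-like
    // inputs, and the trailing pointwise module, then replaces the pointwise
    // instruction with a single mlir_op that runs the whole fusion.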
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins           = r.result;
        auto gemm_based_op = r.instructions["gemm_based_op"];
        auto x_ins         = r.instructions["x"]; // input after contiguous
        auto* pm           = ins->module_inputs().front();
        auto names         = pm->get_parameter_names();
        // Whitelist pointwise operators.
        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
               return not is_pointwise_op_supported_by_mlir(i);
           }))
            return;

        std::sort(names.begin(), names.end());
        module_ref mm = mpm.create_module("mlir_" + pm->name());
        mm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map =
            create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
        std::transform(names.begin(),
                       names.end(),
                       ins->inputs().begin(),
                       std::inserter(param_map, param_map.end()),
                       [&, &anchor = anchor_op](auto name, auto input) {
                           if(input == x_ins)
                               return std::make_pair(pm->get_parameter(name), anchor);
                           return std::make_pair(pm->get_parameter(name),
                                                 mm->add_parameter(name, input->get_shape()));
                       });
        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));

        std::vector<instruction_ref> inputs;
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != gemm_based_op; });
        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
        mpm.get_module().replace_instruction(
            ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
    }
};

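// Fuses a standalone convolution or dot (selected by Matcher) together with
// its reshape-like input operators into an mlir_op, without requiring a
// trailing pointwise module.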
template <auto Matcher>
struct find_mlir_standalone_op
{
    mlir_mode mode = mlir_mode::none;
    auto matcher() const { return Matcher(mode); }
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto conv_based_op = r.result;
        // enable only for fp32/fp16/i8 types
        if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
               return not contains(
                   {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
                   i->get_shape().type());
           }))
            return;

        static size_t counter = 0;
        module_ref mm =
            mpm.create_module("mlir_" + conv_based_op->name() + std::to_string(counter++));
        mm->set_bypass();
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
        mm->add_return({anchor_op});
        mpm.get_module().replace_instruction(
            conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
    }
};

using find_mlir_standalone_convolution_op = find_mlir_standalone_op<&is_mlir_conv>;
using find_mlir_standalone_dot_op         = find_mlir_standalone_op<&is_mlir_dot>;

/**
 * @brief Declares a new MIGraphX environment variable that forces only
 * specific MLIR operations to be generated.
 *
 * If defined, the variable forces MIGraphX to use only the listed operations
 * with MLIR, regardless of the underlying GPU architecture. It accepts a
 * comma-separated list of operations and recognizes the following values:
 * "fused", "convolution", "dot". If the variable is not defined, MIGraphX
 * decides by itself which operations to delegate to MLIR. The variable is
 * intended primarily for rocMLIR developers.
 */
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
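// Example (hypothetical shell invocation, for illustration only; the binary
// name is a placeholder):
//
//   MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,convolution ./my_migraphx_app
//
// Any program linking MIGraphX picks the variable up from the environment.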

bool is_requested(std::string_view option, bool fallback = false)
{
    auto string_value  = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
    if(string_value.empty())
        return fallback;
    const auto options = split_string(string_value, ',');
    return contains(options, option);
}
} // namespace

#endif // MIGRAPHX_MLIR

void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
    const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
    const bool is_navi      = starts_with(device_name, "gfx110");

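    // Select the MLIR mode for a given kind of op: an explicit request via
    // MIGRAPHX_MLIR_USE_SPECIFIC_OPS or a Navi (gfx110x) target enables
    // everything; otherwise the more restrictive of the two supplied modes is
    // used (the enum is ordered all < fast < int8 < none).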
    auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
        if(is_requested(option))
            return mlir_mode::all;
        if(is_navi)
            return mlir_mode::all;
        return std::max(m1, m2);
    };

    match::find_matches(mpm,
                        find_mlir_fused_ops{.conv_mode = get_mode("fused", mlir_mode::fast),
                                            .dot_mode  = get_mode("fused", mlir_mode::fast)});

    match::find_matches(
        mpm,
        find_mlir_standalone_convolution_op{get_mode("convolution", mlir_mode::int8)},
        find_mlir_standalone_dot_op{get_mode("dot", mlir_mode::none)});
#else
    (void)mpm;
#endif
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx