/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

39
40
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Reports whether fusion through the MLIR kernel generator is switched on.
//
// Without MIGRAPHX_MLIR defined at build time this always returns false.
// With it, the result follows the MIGRAPHX_ENABLE_MLIR environment variable;
// when the variable is not enabled a warning is written to std::cerr on every
// call before false is returned.
bool mlir_enabled()
{
#ifndef MIGRAPHX_MLIR
    return false;
#else
    if(enabled(MIGRAPHX_ENABLE_MLIR{}))
        return true;

    // Built with MLIR support but the user has not opted in: say so, then decline.
    std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
                 "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
              << std::endl;
    return false;
#endif
}

#ifdef MIGRAPHX_MLIR

struct mlir_op
Paul Fultz II's avatar
Paul Fultz II committed
65
{
66
    std::string name() const { return "gpu::mlir_op"; }
Paul Fultz II's avatar
Paul Fultz II committed
67
68
69
70
71
72
73
74
75
76
    operation op = make_op("convolution");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
77
        check_shapes{inputs, *this}.packed_or_broadcasted();
Paul Fultz II's avatar
Paul Fultz II committed
78
79
80
81
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
82
83
84
85
86
        auto n     = inputs.size();
        auto* pm   = mods.front();
        auto type  = pm->get_output_shapes().front().type();
        auto shape = op.compute_shape({inputs[n - 2], inputs[n - 1]});
        return shape.with_type(type);
Paul Fultz II's avatar
Paul Fultz II committed
87
88
    }
};
89
MIGRAPHX_REGISTER_OP(mlir_op);

namespace {

// Predicate matcher accepting a "convolution" or "quant_convolution" instruction
// whose "group" attribute is 1 and whose output shape has exactly four dimensions.
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
    if(ins->name() != "convolution" and ins->name() != "quant_convolution")
        return false;
    // Grouped convolutions are rejected.
    value v    = ins->get_operator().to_value();
    auto group = v.at("group").to<int>();
    if(group != 1)
        return false;
    // Avoid MLIR assertion: Index < Length && "Invalid index!"
    return ins->get_shape().lens().size() == 4;
}

107
struct find_mlir_op
Paul Fultz II's avatar
Paul Fultz II committed
108
109
110
{
    auto matcher() const
    {
111
112
113
        auto dot_or_conv = match::skip(match::name("contiguous"))(
            match::any_of(match::name("dot"), is_mlir_conv()).bind("gemm_based_op"));
        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
Paul Fultz II's avatar
Paul Fultz II committed
114
115
116
117
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
118
119
120
121
122
        auto ins           = r.result;
        auto gemm_based_op = r.instructions["gemm_based_op"];
        auto x_ins         = r.instructions["x"]; // input after contiguous
        auto* pm           = ins->module_inputs().front();
        auto names         = pm->get_parameter_names();
Paul Fultz II's avatar
Paul Fultz II committed
123
124
        // Whitelist pointwise operators
        if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
125
126
127
128
129
130
131
132
133
134
135
               return not contains({"@literal",
                                    "@param",
                                    "@return",
                                    "convolution",
                                    "quant_convolution",
                                    "dot",
                                    "add",
                                    "relu",
                                    "dequantizelinear",
                                    "quantizelinear"},
                                   i.name());
Paul Fultz II's avatar
Paul Fultz II committed
136
137
           }))
            return;
138
        // Only fuse with fp32/fp16/int8/int32
Paul Fultz II's avatar
Paul Fultz II committed
139
        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
140
141
142
143
               return not contains({shape::type_t::float_type,
                                    shape::type_t::half_type,
                                    shape::type_t::int8_type,
                                    shape::type_t::int32_type},
144
                                   i->get_shape().type());
Paul Fultz II's avatar
Paul Fultz II committed
145
146
147
148
149
150
151
           }))
            return;
        std::sort(names.begin(), names.end());
        module_ref mm = mpm.create_module("mlir_" + pm->name());
        mm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map;
        auto x    = mm->add_parameter("x" + std::to_string(names.size()),
152
                                   gemm_based_op->inputs().at(0)->get_shape());
Paul Fultz II's avatar
Paul Fultz II committed
153
        auto w    = mm->add_parameter("x" + std::to_string(names.size() + 1),
154
155
                                   gemm_based_op->inputs().at(1)->get_shape());
        auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
Paul Fultz II's avatar
Paul Fultz II committed
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
        std::transform(names.begin(),
                       names.end(),
                       ins->inputs().begin(),
                       std::inserter(param_map, param_map.end()),
                       [&](auto name, auto input) {
                           if(input == x_ins)
                               return std::make_pair(pm->get_parameter(name), conv);
                           return std::make_pair(pm->get_parameter(name),
                                                 mm->add_parameter(name, input->get_shape()));
                       });
        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));

        std::vector<instruction_ref> inputs;
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
172
173
                     [&](auto input) { return input != gemm_based_op; });
        inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
Paul Fultz II's avatar
Paul Fultz II committed
174
        mpm.get_module().replace_instruction(
175
            ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
Paul Fultz II's avatar
Paul Fultz II committed
176
177
    }
};

} // namespace

#endif

// Runs the find_mlir_op rewrite over the module managed by `mpm`.
// When MIGraphX is built without MIGRAPHX_MLIR the pass does nothing.
void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
    match::find_matches(mpm, find_mlir_op{});
#else
    // Silence the unused-parameter warning in builds without MLIR support.
    (void)mpm;
#endif
}

} // namespace gpu

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx