"configs/models/wizardlm/hf_wizardmath_7b_v1_1.py" did not exist on "32f40a8f83de02f58f8f7eaaf37a8ab5a18dc77d"
fuse_mlir.cpp 14.5 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/mlir.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <algorithm>
#include <iostream>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);

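// Reports whether the MLIR kernel generator is active: MLIR support must be
// compiled in (MIGRAPHX_MLIR) and explicitly opted into at runtime via the
// MIGRAPHX_ENABLE_MLIR environment variable.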
bool mlir_enabled()
{
#ifdef MIGRAPHX_MLIR
    const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
    if(mlir_enabled)
    {
        return true;
    }
    else
    {
        std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
                     "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
                  << std::endl;
        return false;
    }
#else
    return false;
#endif
}

#ifdef MIGRAPHX_MLIR

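// Carrier operation for a fusion that will be compiled by MLIR. It holds a
// single submodule describing the fused computation; "op" records the anchor
// gemm/convolution the fusion was built around.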
struct mlir_op
{
    std::string name() const { return "gpu::mlir_op"; }
    operation op = make_op("convolution");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

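    // Derives the output shape of the fused submodule by propagating the given
    // input shapes through every instruction until @return is reached.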
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.packed_or_broadcasted();
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");

        module_ref mod = mods[0];
        auto type      = mod->get_output_shapes().front().type();
        std::unordered_map<instruction_ref, shape> ins_shapes;
        size_t param_cnt               = 0;
        std::vector<std::string> names = mod->get_parameter_names();
        std::sort(names.begin(), names.end());
        for(const std::string& param_name : names)
        {
            ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
        }
        for(auto ins : iterator_for(*mod))
        {
            if(ins->name() == "@param")
            {
                continue;
            }
            if(ins->name() == "@literal")
            {
                ins_shapes[ins] = ins->get_shape();
                continue;
            }
            if(ins->name() == "@return")
            {
                auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
                if(not s.standard())
                    MIGRAPHX_THROW("MLIR doesnt support non-standard output");
                return s;
            }
            std::vector<shape> input_shapes;
            input_shapes.resize(ins->inputs().size());
            std::transform(ins->inputs().begin(),
                           ins->inputs().end(),
                           input_shapes.begin(),
                           [&](auto in) { return ins_shapes[in]; });
            ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
        }
        MIGRAPHX_THROW("No return found in the submodule");
    }
};
MIGRAPHX_REGISTER_OP(mlir_op);

namespace {
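// Clones the gemm/convolution into the new submodule, hoisting any chain of
// shape-only producers (slice, transpose, contiguous, reshape) on its inputs
// along with it. Returns the cloned anchor op together with the original
// upstream instructions that now feed the submodule parameters.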
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
    std::vector<instruction_ref> top_inputs;
    std::vector<instruction_ref> imm_inputs;
    size_t input_cnt = 0;
    for(instruction_ref input : gemm_based_op->inputs())
    {
        std::vector<operation> op_stream;
        while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
        {
            op_stream.push_back(input->get_operator());
            input = input->inputs().at(0);
        }
        top_inputs.push_back(input);
        instruction_ref prev_input =
            mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
        for(const auto& op : reverse(op_stream))
        {
            prev_input = mm->add_instruction(op, {prev_input});
        }
        imm_inputs.push_back(prev_input);
    }
    instruction_ref new_gemm_based_op =
        mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
    return {new_gemm_based_op, top_inputs};
}

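// Predicate matcher for convolutions MLIR can currently handle: ungrouped
// convolutions whose output is rank 4.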
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{
    if(ins->name() != "convolution" and ins->name() != "quant_convolution")
        return false;
    value v    = ins->get_operator().to_value();
    auto group = v.at("group").to<int>();
    if(group != 1)
        return false;
    // Avoid MLIR assertion: Index < Length && "Invalid index!"
    if(ins->get_shape().lens().size() != 4)
        return false;
    return true;
}

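// Fuses a pointwise module with the dot/convolution that feeds it so that MLIR
// can compile the whole chain as one kernel.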
struct find_mlir_fused_ops
{
    auto matcher() const
    {
        auto dot_or_conv = match::skip(match::name("contiguous"))(
            match::any_of(match::name("dot"), match::name("quant_dot"), is_mlir_conv())
                .bind("gemm_based_op"));
        return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
    }

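    // Copies the literals of the pointwise module into the MLIR submodule,
    // multibroadcasting each one to the given shape, and returns a map from
    // the original literal instructions to their broadcast replacements.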
    std::unordered_map<instruction_ref, instruction_ref>
    create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
    {
        std::unordered_map<instruction_ref, instruction_ref> ins_map;
        for(auto ins : iterator_for(*pm))
        {
            if(ins->name() != "@literal")
            {
                continue;
            }
            literal r               = ins->get_literal();
            instruction_ref literal = mm->add_literal(r);
            instruction_ref mbcast  = mm->add_instruction(
                make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
            ins_map[ins] = mbcast;
        }
        return ins_map;
    }

    // Whitelist supported fusion options, including imposing type constraints
    // for cases where MLIR only supports an operation (usually a pointwise function)
    // on particular types.
    bool is_pointwise_op_supported_by_mlir(const instruction& i) const
    {
        using type_t                                      = shape::type_t;
        const auto& name                                  = i.name();
        const auto result_type                            = i.get_shape().type();
        const std::initializer_list<type_t> allowed_types = {type_t::float_type,
                                                             type_t::half_type,
                                                             type_t::int8_type,
                                                             type_t::int32_type,
                                                             type_t::bool_type};
        // Preliminary type check.
        if(not contains(allowed_types, result_type))
        {
            return false;
        }
        const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
        const std::initializer_list<std::string> no_bool_ops  = {
            "convolution",
            "quant_convolution",
            "dot",
            "quant_dot",
            "add",
            "clip",
            "relu",
            "sub",
            "mul",
            "div",
            "pow",
            "where",
            "quantizelinear",
            "dequantizelinear",
            "abs",
            "neg",
        };
        const std::initializer_list<std::string> fp_only_ops = {
            "ceil",
            "erf",
            "exp",
            "floor",
            "log",
            "recip",
            "rsqrt",
            "sigmoid",
            "softmax",
            "tanh",
        };
        bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
        if(contains(any_type_ops, name))
            return true;
        if(result_type != type_t::bool_type and contains(no_bool_ops, name))
            return true;
        if(is_float and contains(fp_only_ops, name))
            return true;
        // Only conversions between floating-point types are known to be
        // unambiguously supported.
        if(is_float and name == "convert")
        {
            return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
                return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
            });
        }
        return false;
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins           = r.result;
        auto gemm_based_op = r.instructions["gemm_based_op"];
        auto x_ins         = r.instructions["x"]; // input after contiguous
        auto* pm           = ins->module_inputs().front();
        auto names         = pm->get_parameter_names();
        // Whitelist pointwise operators.
        if(std::any_of(pm->begin(), pm->end(), [&](const auto& i) {
               return not is_pointwise_op_supported_by_mlir(i);
           }))
            return;

        std::sort(names.begin(), names.end());
        module_ref mm = mpm.create_module("mlir_" + pm->name());
        mm->set_bypass();
        std::unordered_map<instruction_ref, instruction_ref> param_map =
            create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
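        // Map each pointwise parameter either to the fused gemm/convolution
        // result (for the input that arrived through "x") or to a fresh
        // submodule parameter of the same shape.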
        std::transform(names.begin(),
                       names.end(),
                       ins->inputs().begin(),
                       std::inserter(param_map, param_map.end()),
                       [&, &anchor_op = anchor_op](auto name, auto input) {
                           if(input == x_ins)
                               return std::make_pair(pm->get_parameter(name), anchor_op);
                           return std::make_pair(pm->get_parameter(name),
                                                 mm->add_parameter(name, input->get_shape()));
                       });
        mm->add_return(mm->insert_instructions(mm->end(), pm, param_map));

        std::vector<instruction_ref> inputs;
        std::copy_if(ins->inputs().begin(),
                     ins->inputs().end(),
                     std::back_inserter(inputs),
                     [&](auto input) { return input != gemm_based_op; });
        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
        mpm.get_module().replace_instruction(
            ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
    }
};

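// Outlines a standalone convolution or dot (one with no fusable pointwise
// consumer) into its own MLIR submodule; the matcher comes from the derived
// structs below.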
struct find_mlir_standalone_op
{
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto conv_based_op = r.result;
        // enable only for fp32/fp16/i8 types
        if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
               return not contains(
                   {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
                   i->get_shape().type());
           }))
            return;

        static size_t counter = 0;
        module_ref mm         = mpm.create_module("mlir_" + std::to_string(counter++));
        mm->set_bypass();
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
        mm->add_return({anchor_op});
        mpm.get_module().replace_instruction(
            conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
    }
};

struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
    auto matcher() const { return match::name("convolution"); }
};

struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
    auto matcher() const { return match::name("dot"); }
};

/**
 * @brief Declares a MIGraphX environment variable that restricts MLIR code
 * generation to specific operations.
 *
 * If defined, the variable forces MIGraphX to use MLIR only for the listed
 * operations, regardless of the underlying GPU architecture. It accepts a
 * comma-separated list and recognizes the following operations: "fused",
 * "convolution", "dot". If the variable is not defined, MIGraphX decides by
 * itself which operations to delegate to MLIR. The variable is intended
 * primarily for rocMLIR developers.
 */
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
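// Example usage (hypothetical command line; any program that loads MIGraphX
// honors the variable):
//   MIGRAPHX_MLIR_USE_SPECIFIC_OPS=fused,dot migraphx-driver compile model.onnx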
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }

bool is_requested(std::string_view option)
{
    assert(not is_self_decide());
    auto string_value  = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
    const auto options = split_string(string_value, ',');
    return contains(options, option);
}

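// Decides whether MLIR should handle the given op kind: either the user asked
// for it explicitly via MIGRAPHX_MLIR_USE_SPECIFIC_OPS, or, in self-decide
// mode, a built-in per-device heuristic applies.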
bool is_enabled(std::string_view op_name, context* ctx)
{
    if(is_self_decide())
    {
        if(op_name == "fused")
        {
            return true;
        }
        else if(op_name == "convolution")
        {
            if(ctx == nullptr)
            {
                return false;
            }
            else
            {
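                // In self-decide mode, standalone convolutions are delegated
                // to MLIR only on the Navi (gfx110*) family.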
                const auto& device = ctx->get_current_device();
                const std::string navi_family{"gfx110"};
                return starts_with(device.get_gfx_name(), navi_family);
            }
        }
        else
        {
            return false;
        }
    }
    return is_requested(op_name);
}
} // namespace

#endif // MIGRAPHX_MLIR

void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
    if(is_enabled("fused", this->ctx))
    {
        match::find_matches(mpm, find_mlir_fused_ops{});
    }

    if(is_enabled("convolution", this->ctx))
    {
        match::find_matches(mpm, find_mlir_standalone_convolution_op{});
    }

    if(is_enabled("dot", this->ctx))
    {
        match::find_matches(mpm, find_mlir_standalone_dot_op{});
    }
#else
    (void)mpm;
#endif
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx