fuse_ops.cpp 26.2 KB
Newer Older
kahmed10's avatar
kahmed10 committed
1
2
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
3
4
5
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
kahmed10's avatar
kahmed10 committed
6
#include <migraphx/gpu/clip.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/gpu/convolution.hpp>
8
#include <migraphx/gpu/oper.hpp>
kahmed10's avatar
kahmed10 committed
9
10
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/mul.hpp>
11
#include <migraphx/gpu/gemm.hpp>
kahmed10's avatar
kahmed10 committed
12
#include <migraphx/gpu/device/layernorm.hpp>
kahmed10's avatar
kahmed10 committed
13
#include <migraphx/gpu/device/gelu.hpp>
Paul's avatar
Paul committed
14
#include <migraphx/gpu/device/mul_add.hpp>
15
16
17
18
19
#include <migraphx/gpu/device/add_clip.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_sigmoid.hpp>
#include <migraphx/gpu/device/add_tanh.hpp>
#include <migraphx/gpu/device/mul_add_relu.hpp>
Paul's avatar
Paul committed
20
#include <migraphx/gpu/device/add.hpp>
Paul's avatar
Paul committed
21
#include <migraphx/instruction.hpp>
22
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
23
#include <migraphx/array.hpp>
kahmed10's avatar
kahmed10 committed
24
#include <migraphx/op/clip.hpp>
kahmed10's avatar
kahmed10 committed
25
#include <cmath>
Paul's avatar
Paul committed
26
27

namespace migraphx {
Paul's avatar
Paul committed
28
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
29
30
namespace gpu {

31
32
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)

Paul's avatar
Paul committed
33
34
35
36
37
38
39
40
struct fusion
{
    using op_t = miopenFusionOpDescriptor_t;
    shared<fusion_plan_descriptor> fp;

    // Used as a temporary hack to keep descriptor references alive
    std::vector<std::shared_ptr<void>> storage;

Paul's avatar
Paul committed
41
    template <class T>
Paul's avatar
Paul committed
42
43
44
45
46
47
48
    auto keep_alive(T x)
    {
        auto result = share(std::move(x));
        storage.push_back(result);
        return result;
    }

49
50
    fusion() = default;

Paul's avatar
Paul committed
51
52
    fusion(const shape& input)
    {
53
        assert(input.standard());
Paul's avatar
Paul committed
54
        auto t = make_tensor(input);
Paul's avatar
Paul committed
55
        fp     = make_fusion_plan(t);
56
        assert(fp);
Paul's avatar
Paul committed
57
58
59
60
61
        keep_alive(std::move(t));
    }

    op_t operator[](std::size_t i) const
    {
62
        assert(fp);
Paul's avatar
Paul committed
63
64
65
        op_t result;
        auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
66
            MIGRAPHX_THROW("Failed retrieving operator at " + std::to_string(i));
Paul's avatar
Paul committed
67
68
69
        return result;
    }

70
71
72
73
74
    auto get() const
    {
        assert(fp);
        return fp.get();
    }
Paul's avatar
Paul committed
75
76
77

    op_t create_bias(const shape& bias)
    {
78
        assert(fp);
Paul's avatar
Paul committed
79
        op_t result;
Paul's avatar
Paul committed
80
81
        auto b      = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
        auto t      = keep_alive(make_tensor(b));
Paul's avatar
Paul committed
82
83
        auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
84
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
85
86
87
88
89
        return result;
    }

    op_t create_relu()
    {
90
        assert(fp);
Paul's avatar
Paul committed
91
92
93
        op_t result;
        auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
94
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
95
96
97
98
99
        return result;
    }

    op_t create_conv(const op::convolution& op, const shape& weights)
    {
100
        assert(fp);
Paul's avatar
Paul committed
101
        op_t result;
Paul's avatar
Paul committed
102
103
        auto cd     = keep_alive(make_conv(op));
        auto t      = keep_alive(make_tensor(weights));
Paul's avatar
Paul committed
104
105
        auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
106
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
107
108
        return result;
    }
Paul's avatar
Paul committed
109
110
111

    shape get_workspace(context&)
    {
112
        // assert(fp);
Paul's avatar
Paul committed
113
114
115
116
117
        // TODO: Use zero workspace for now
        std::size_t ws_size = 0;
        // int algo_count = 1;
        // miopenConvFwdAlgorithm_t algo;
        // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
Paul's avatar
Paul committed
118
119
        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
        // algo);
Paul's avatar
Paul committed
120
121
122
123
124
        return shape{shape::int8_type, {ws_size}};
    }

    void compile(context& ctx)
    {
125
        assert(fp);
Paul's avatar
Paul committed
126
        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
Paul's avatar
Paul committed
127
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
128
            MIGRAPHX_THROW("Compiling fusion plan failed");
Paul's avatar
Paul committed
129
130
    }

Paul's avatar
Paul committed
131
132
133
134
    argument execute(context& ctx,
                     const fused_operator_args& fargs,
                     const argument& x,
                     const argument& y) const
Paul's avatar
Paul committed
135
    {
136
        assert(fp);
Paul's avatar
Paul committed
137
138
        auto x_td   = make_tensor(x.get_shape());
        auto y_td   = make_tensor(y.get_shape());
Paul's avatar
Paul committed
139
        auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(),
Paul's avatar
Paul committed
140
141
142
143
144
145
                                              fp.get(),
                                              x_td.get(),
                                              x.implicit(),
                                              y_td.get(),
                                              y.implicit(),
                                              fargs.get());
Paul's avatar
Paul committed
146
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
147
            MIGRAPHX_THROW("Failed to execute fusion plan");
Paul's avatar
Paul committed
148
149
        return y;
    }
Paul's avatar
Paul committed
150
151
};

Paul's avatar
Paul committed
152
// Matches an instruction whose shape is broadcasted, has exactly four strides,
// and whose only non-zero stride is at index 1.
MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
{
    const auto& candidate = ins->get_shape();
    if(not candidate.broadcasted())
        return false;
    const auto& strides = candidate.strides();
    if(strides.size() != 4)
        return false;
    return strides[0] == 0 and strides[1] != 0 and strides[2] == 0 and strides[3] == 0;
}

Paul's avatar
Paul committed
159
// Matches a "gpu::convolution" instruction that this pass is willing to hand
// to the MIOpen fusion API. Every check below is a reason to reject.
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
    // Fusion can be switched off through the environment
    if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
        return false;
    if(ins->name() != "gpu::convolution")
        return false;
    // Only float outputs are considered
    if(ins->get_shape().type() != shape::float_type)
        return false;

    const auto weights_shape = ins->inputs().at(1)->get_shape();
    const auto& wlens        = weights_shape.lens();
    assert(wlens.size() == 4);

    const auto conv        = any_cast<miopen_convolution>(ins->get_operator());
    const bool is_winograd = (conv.algo == miopenConvolutionFwdAlgoWinograd);

    // Grouped convolutions are not fused
    if(conv.op.group > 1)
        return false;
    if(wlens[1] > 512 and not is_winograd)
        return false;

    // Do not fuse non-symmetric input
    const auto input_lens = ins->inputs().at(0)->get_shape().lens();
    if(input_lens[2] != input_lens[3] or wlens[2] != wlens[3])
        return false;

    // Don't fuse winograd for non-3x3s since there is no fused winograd for those configs
    if(is_winograd and wlens[2] != 3 and wlens[3] != 3 and contains({{1, 1}}, conv.op.stride))
        return false;

    // Finally restrict padding, stride and dilation to the listed configurations
    if(not contains({{0, 0}, {1, 1}, {2, 2}}, conv.op.padding))
        return false;
    if(not contains({{0, 0}, {1, 1}}, conv.op.stride))
        return false;
    return contains({{1, 1}}, conv.op.dilation);
}

189
struct hip_triadd : ternary_device<hip_triadd, &device::add>
Paul's avatar
Paul committed
190
191
{
};
192
MIGRAPHX_REGISTER_OP(hip_triadd)
Paul's avatar
Paul committed
193

194
struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
195
196
{
};
197
MIGRAPHX_REGISTER_OP(hip_triadd_clip)
kahmed10's avatar
kahmed10 committed
198

199
struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
200
201
{
};
202
MIGRAPHX_REGISTER_OP(hip_add_clip)
kahmed10's avatar
kahmed10 committed
203

204
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
Paul's avatar
Paul committed
205
206
{
};
207
MIGRAPHX_REGISTER_OP(hip_triadd_relu)
Paul's avatar
Paul committed
208

209
210
211
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
{
};
212
MIGRAPHX_REGISTER_OP(hip_triadd_sigmoid)
213
214
215
216

struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
{
};
217
MIGRAPHX_REGISTER_OP(hip_triadd_tanh)
218
219
220
221

struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
{
};
222
MIGRAPHX_REGISTER_OP(hip_add_relu)
223
224
225
226

struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
{
};
227
MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
228
229

struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
Paul's avatar
Paul committed
230
231
{
};
232
MIGRAPHX_REGISTER_OP(hip_add_tanh)
Paul's avatar
Paul committed
233

kahmed10's avatar
kahmed10 committed
234
235
struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
{
236
237
    // Empty finalize to skip dimension reduction
    void finalize(context&, const shape&, const std::vector<shape>&) {}
kahmed10's avatar
kahmed10 committed
238
};
239
MIGRAPHX_REGISTER_OP(hip_layernorm)
kahmed10's avatar
kahmed10 committed
240

Paul Fultz II's avatar
Paul Fultz II committed
241
242
243
244
245
246
247
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
    // Empty finalize to skip dimension reduction
    void finalize(context&, const shape&, const std::vector<shape>&) {}
};
MIGRAPHX_REGISTER_OP(hip_triadd_layernorm)

kahmed10's avatar
kahmed10 committed
248
249
250
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
251
MIGRAPHX_REGISTER_OP(hip_gelu)
kahmed10's avatar
kahmed10 committed
252
253
254
255

struct hip_add_gelu : binary_device<hip_add_gelu, &device::add_gelu>
{
};
256
MIGRAPHX_REGISTER_OP(hip_add_gelu)
kahmed10's avatar
kahmed10 committed
257
258
259
260

struct hip_gelu_new : unary_device<hip_gelu_new, &device::gelu_new>
{
};
261
MIGRAPHX_REGISTER_OP(hip_gelu_new)
kahmed10's avatar
kahmed10 committed
262
263
264
265

struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
266
MIGRAPHX_REGISTER_OP(hip_add_gelu_new)
kahmed10's avatar
kahmed10 committed
267

268
struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
Paul's avatar
Paul committed
269
270
{
};
271
MIGRAPHX_REGISTER_OP(hip_mul_add)
Paul's avatar
Paul committed
272

273
struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
Paul's avatar
Paul committed
274
275
{
};
276
MIGRAPHX_REGISTER_OP(hip_mul_add_relu)
Paul's avatar
Paul committed
277

Paul's avatar
Paul committed
278
279
280
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last arguments is the broadcasted one
Paul's avatar
Paul committed
281
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
282
283
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
Paul's avatar
Paul committed
284
285
    if(it != last)
        std::swap(*it, *std::prev(last));
Paul's avatar
Paul committed
286
287
288
289
290
}

void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first arguments is the standard one
Paul's avatar
Paul committed
291
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
292
293
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
Paul's avatar
Paul committed
294
    if(it != last)
Paul's avatar
Paul committed
295
296
297
        std::swap(*it, args.front());
}

kahmed10's avatar
kahmed10 committed
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
struct find_layernorm
{
    template <class... Ts>
    static auto multibroadcast_op(Ts... xs)
    {
        return match::name("multibroadcast")(match::arg(0)(xs...));
    }

    static auto x_minus_mean()
    {
        return match::name("gpu::sub")(
            match::arg(0)(match::any().bind("x")),
            match::arg(1)(multibroadcast_op(match::name("gpu::reduce_mean"))));
    }

    static auto variance()
    {
        return match::name("gpu::reduce_mean")(match::arg(0)(
            match::name("gpu::pow")(match::arg(0)(x_minus_mean()),
                                    match::arg(1)(multibroadcast_op(match::has_value(2.0f))))));
    }

    static auto layernorm_onnx()
    {
        return match::name("gpu::div")(
            match::arg(0)(x_minus_mean()),

            match::arg(1)(multibroadcast_op(
                match::name("gpu::sqrt")(match::arg(0)(match::name("gpu::add")(match::either_arg(
                    0, 1)(variance(), multibroadcast_op(match::has_value(1e-12f)))))))));
    }

    auto matcher() const { return layernorm_onnx(); }

332
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
333
334
335
336
337
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

338
339
340
341
342
343
344
345
346
        // We dont fuse for non-standard layouts
        if(not x_ins->get_shape().standard())
            return;

        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

kahmed10's avatar
kahmed10 committed
347
348
349
350
        p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
struct find_triadd_layernorm
{
    auto matcher() const
    {
        return match::name("gpu::layernorm")(match::arg(0)(match::name("gpu::triadd")(
            match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
    }

    void apply(program& p, const match::matcher_result& r) const
    {
        auto ins    = r.result;
        auto triadd = ins->inputs().front();
        p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
    }
};

kahmed10's avatar
kahmed10 committed
367
368
369
370
371
372
373
374
375
struct find_gelu
{

    static auto erf_fn()
    {
        return match::name("gpu::erf")(
            match::used_once(),
            match::arg(0)(match::used_once(),
                          match::name("gpu::mul")(match::either_arg(0, 1)(
kahmed10's avatar
kahmed10 committed
376
377
                              match::none_of(match::has_value(M_SQRT1_2, 1e-3)).bind("x"),
                              match::has_value(M_SQRT1_2, 1e-3)))));
kahmed10's avatar
kahmed10 committed
378
379
    }

Paul Fultz II's avatar
Paul Fultz II committed
380
381
382
383
384
385
386
387
388
    static auto add_erf()
    {
        return match::name("gpu::add")(
            match::used_once(),
            match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f))));
    }

    static auto one_half() { return match::args(match::has_value(0.5f)); }

kahmed10's avatar
kahmed10 committed
389
390
    auto matcher() const
    {
Paul Fultz II's avatar
Paul Fultz II committed
391
        return match::unordered_tree("gpu::mul", one_half(), add_erf(), match::any());
kahmed10's avatar
kahmed10 committed
392
393
    }

394
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

        p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
    }
};

struct find_add_gelu
{
    auto matcher() const
    {
        return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

411
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
412
413
414
415
416
417
418
419
420
421
422
423
424
425
    {
        auto add_ins = r.instructions["add"];
        auto ins     = r.result;
        auto args    = add_ins->inputs();
        move_standard_front(args);
        move_broadcasted_back(args);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_add_gelu{}, args);
    }
};

struct find_gelu_new
{
kahmed10's avatar
kahmed10 committed
426
    bool fast_math = true;
kahmed10's avatar
kahmed10 committed
427
428
429
430
431
432
433
434
435
436
437
438

    static auto pow_fn()
    {
        return match::name("gpu::pow")(match::used_once(),
                                       match::arg(1)(match::args(match::has_value(3.0f))));
    }

    static auto tanh_fn()
    {
        return match::name("gpu::tanh")(
            match::used_once(),
            match::arg(0)(match::name("gpu::mul")(match::either_arg(0, 1)(
kahmed10's avatar
kahmed10 committed
439
                match::args(match::has_value(sqrt(M_2_PI), 1e-3)),
kahmed10's avatar
kahmed10 committed
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
                match::name("gpu::add")(
                    match::any_arg(0, 1)(match::name("gpu::mul")(match::either_arg(0, 1)(
                        match::args(match::has_value(0.044715f)), pow_fn()))))))));
    }

    auto matcher() const
    {
        return match::name("gpu::mul")(
            match::used_once(),
            match::either_arg(0, 1)(
                match::any().bind("x"),
                match::name("gpu::add")(match::any_arg(0, 1)(match::name("gpu::mul")(
                    match::either_arg(0, 1)(match::args(match::has_value(0.5f)), tanh_fn()))))));
    }

455
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
456
457
458
459
460
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

kahmed10's avatar
kahmed10 committed
461
        if(not fast_math)
kahmed10's avatar
kahmed10 committed
462
463
464
465
466
467
468
469
470
471
472
473
474
            p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
        else
            p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
    }
};

struct find_add_gelu_new
{
    auto matcher() const
    {
        return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

475
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
476
477
478
479
480
481
482
483
484
485
486
487
    {
        auto add_ins = r.instructions["add"];
        auto ins     = r.result;
        auto args    = add_ins->inputs();
        move_standard_front(args);
        move_broadcasted_back(args);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_add_gelu_new{}, args);
    }
};

kahmed10's avatar
kahmed10 committed
488
489
490
491
492
493
struct find_add_clip
{
    auto matcher() const
    {
        return match::name(std::unordered_set<std::string>{"gpu::clip", "gpu::clipped_relu"})(
            match::arg(0)(match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
494
                                        match::name("gpu::triadd"),
kahmed10's avatar
kahmed10 committed
495
496
497
498
                                        match::any_of[match::inputs()](match::standard_shape()))
                              .bind("add")));
    }

499
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
500
    {
kahmed10's avatar
kahmed10 committed
501
502
503
504
505
506
507
508
509
510
        auto add_ins  = r.instructions["add"];
        auto ins      = r.result;
        auto ins_args = ins->inputs();
        auto add_args = add_ins->inputs();
        move_standard_front(add_args);
        move_broadcasted_back(add_args);

        // Use the allocation from the clip operator
        add_args.pop_back();
        add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
kahmed10's avatar
kahmed10 committed
511
        if(add_ins->name() == "gpu::add")
kahmed10's avatar
kahmed10 committed
512
            p.replace_instruction(ins, hip_add_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
513
        else if(add_ins->name() == "gpu::triadd")
kahmed10's avatar
kahmed10 committed
514
            p.replace_instruction(ins, hip_triadd_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
515
516
517
    }
};

518
struct find_add_unary
Paul's avatar
Paul committed
519
{
520
521
522
    std::string op_name;
    operation binary_add_op;
    operation ternary_add_op;
Paul's avatar
Paul committed
523
524
    auto matcher() const
    {
525
        return match::name(op_name)(match::arg(0)(
Paul's avatar
Paul committed
526
            match::used_once(),
Paul's avatar
Paul committed
527
            match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
528
                          match::name("gpu::triadd"),
Paul's avatar
Paul committed
529
530
531
                          match::any_of(match::name("@literal"),
                                        match::any_of[match::inputs()](match::standard_shape())))
                .bind("add")));
Paul's avatar
Paul committed
532
    }
Paul's avatar
Paul committed
533

534
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
535
    {
Paul's avatar
Paul committed
536
        auto add_ins = r.instructions["add"];
Paul's avatar
Paul committed
537
538
        auto ins     = r.result;
        auto args    = add_ins->inputs();
Paul's avatar
Paul committed
539
540
541
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
542
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
543
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
544
        if(add_ins->name() == "gpu::add")
545
            p.replace_instruction(ins, binary_add_op, args);
kahmed10's avatar
kahmed10 committed
546
        else if(add_ins->name() == "gpu::triadd")
547
            p.replace_instruction(ins, ternary_add_op, args);
Paul's avatar
Paul committed
548
549
550
    }
};

Paul's avatar
Paul committed
551
struct find_triadd
Paul's avatar
Paul committed
552
553
554
{
    auto matcher() const
    {
Paul's avatar
Paul committed
555
        return match::name("gpu::add")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
556
            match::name("gpu::add")(match::used_once()).bind("add"),
Paul's avatar
Paul committed
557
558
559
            match::any(match::any_of(match::name("@literal"),
                                     match::any_of[match::inputs()](match::standard_shape())))
                .bind("input")));
Paul's avatar
Paul committed
560
561
    }

562
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
563
    {
Paul's avatar
Paul committed
564
565
566
567
        auto add_ins   = r.instructions["add"];
        auto input_ins = r.instructions["input"];
        auto ins       = r.result;
        auto args      = add_ins->inputs();
568

Paul's avatar
Paul committed
569
        auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
570
        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 2)
Paul's avatar
Paul committed
571
572
            return;
        args.insert(args.begin(), input_ins);
Paul's avatar
Paul committed
573
574
575
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
576
577
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_triadd{}, args);
Paul's avatar
Paul committed
578
    }
Paul's avatar
Paul committed
579
580
};

Paul's avatar
Paul committed
581
582
583
584
struct find_mul_add
{
    auto matcher() const
    {
Paul's avatar
Paul committed
585
586
        return match::name("gpu::add")(match::either_arg(0, 1)(
            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
Paul's avatar
Paul committed
587
588
    }

589
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
590
    {
Paul's avatar
Paul committed
591
592
593
594
        auto mul_ins = r.instructions["mul"];
        auto b_ins   = r.instructions["b"];
        auto ins     = r.result;
        auto args    = mul_ins->inputs();
Paul's avatar
Paul committed
595
596
597
598
599
600
601
602
603
604
605
        assert(mul_ins != b_ins);

        move_standard_front(args);
        move_broadcasted_back(args);
        args.insert(std::prev(args.end()), b_ins);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_mul_add{}, args);
    }
};

Paul's avatar
Paul committed
606
607
608
609
struct find_mul_add_relu
{
    auto matcher() const
    {
Paul's avatar
Paul committed
610
        return match::name("gpu::relu")(
kahmed10's avatar
kahmed10 committed
611
            match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
Paul's avatar
Paul committed
612
613
    }

614
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
615
616
    {
        auto mul_add_ins = r.instructions["mul_add"];
Paul's avatar
Paul committed
617
618
        auto ins         = r.result;
        auto args        = mul_add_ins->inputs();
Paul's avatar
Paul committed
619
620
621
622
623
624
625

        // Use the allocation from the relu operator
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_mul_add_relu{}, args);
    }
};

Paul's avatar
Paul committed
626
627
628
struct miopen_conv_bias
{
    op::convolution op;
629
630
631
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
Paul's avatar
Paul committed
632

Paul's avatar
Paul committed
633
634
635
636
637
638
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Paul committed
639
640
641
642
643
644
645
    std::string name() const { return "gpu::conv_bias"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
646
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
647
    {
Paul's avatar
Paul committed
648
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
649
        float alpha = 1;
Paul's avatar
Paul committed
650
        float beta  = 0;
Paul's avatar
Paul committed
651
652
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
653
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Paul committed
654
655
    }

656
657
658
659
660
661
662
663
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        f.compile(ctx);
    }

Paul's avatar
Paul committed
664
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
665
666
667
668
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
669
};
670
MIGRAPHX_REGISTER_OP(miopen_conv_bias)
Paul's avatar
Paul committed
671

Paul's avatar
Add cbr  
Paul committed
672
673
674
struct miopen_conv_bias_relu
{
    op::convolution op;
675
676
677
678
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
    fusion::op_t relu = {};
Paul's avatar
Add cbr  
Paul committed
679

Paul's avatar
Paul committed
680
681
682
683
684
685
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Add cbr  
Paul committed
686
687
688
689
690
691
692
    std::string name() const { return "gpu::conv_bias_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
693
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Add cbr  
Paul committed
694
695
    {
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
696
        float alpha = 1;
Paul's avatar
Paul committed
697
        float beta  = 0;
Paul's avatar
Add cbr  
Paul committed
698
699
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
700
701
        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Add cbr  
Paul committed
702
    }
703
704
705
706
707
708
709
710
711
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        relu = f.create_relu();
        f.compile(ctx);
    }

Paul's avatar
Paul committed
712
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
713
714
715
716
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Add cbr  
Paul committed
717
};
718
MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu)
Paul's avatar
Add cbr  
Paul committed
719

Paul's avatar
Paul committed
720
template <class... Ms>
Paul's avatar
Add cbr  
Paul committed
721
722
auto conv_bias(Ms... ms)
{
Paul's avatar
Paul committed
723
    return match::name("gpu::add")(
Paul's avatar
Paul committed
724
725
        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
                                fusable_conv(match::used_once()).bind("conv")),
Paul's avatar
Paul committed
726
        ms...);
Paul's avatar
Paul committed
727
728
}

Paul's avatar
Paul committed
729
template <class Op>
730
void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
Paul's avatar
Paul committed
731
732
733
734
735
736
737
738
739
740
{
    auto conv_ins    = r.instructions["conv"];
    auto bias_ins    = r.instructions["bias"];
    auto ins         = r.result;
    auto input_ins   = conv_ins->inputs().at(0);
    auto weights_ins = conv_ins->inputs().at(1);
    auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
    auto alloc_ins   = ins->inputs().back();
    auto old_ws_ins  = conv_ins->inputs().at(2);

741
    Op cb{conv_op};
Paul's avatar
Paul committed
742
    // TODO: Insert ws allocation
Paul's avatar
Paul committed
743
    auto ws = cb.get_workspace(ctx);
Paul's avatar
Paul committed
744
    (void)ws;
Paul's avatar
Paul committed
745
    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
Paul's avatar
Add cbr  
Paul committed
746
747
}

Paul's avatar
Paul committed
748
struct find_conv_bias
Paul's avatar
Paul committed
749
{
Paul's avatar
Paul committed
750
    context* ctx = nullptr;
Paul's avatar
Paul committed
751
752
    auto matcher() const
    {
kahmed10's avatar
kahmed10 committed
753
754
        return conv_bias(match::none_of(
            match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
Paul's avatar
Paul committed
755
756
    }

757
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
758
    {
Paul's avatar
Paul committed
759
        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
Paul's avatar
Paul committed
760
761
762
    }
};

Paul's avatar
Paul committed
763
struct find_conv_bias_relu
Paul's avatar
Add cbr  
Paul committed
764
765
{
    context* ctx = nullptr;
Paul's avatar
Paul committed
766
    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
Paul's avatar
Add cbr  
Paul committed
767

768
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Add cbr  
Paul committed
769
    {
Paul's avatar
Paul committed
770
        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
Paul's avatar
Add cbr  
Paul committed
771
772
773
    }
};

774
775
776
777
778
779
780
781
782
783
struct find_gemm_add
{
    auto matcher() const
    {
        return match::name("gpu::add")(
            match::all_of[match::inputs()](match::standard_shape()),
            match::either_arg(0, 1)(match::used_once().bind("c"),
                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
    }

784
    void apply(module& p, match::matcher_result r) const
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto c_ins    = r.instructions["c"];

        auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());

        // Already fused gemm
        if(not float_equal(gemm.op.beta, 0))
            return;

        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
               return not i->get_shape().standard();
           }))
            return;

        auto inputs = gemm_ins->inputs();
        inputs.pop_back();

        auto copy_ins = c_ins;

        // Insert copy
        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
        {
            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
        }
        inputs.push_back(copy_ins);
        inputs.push_back(copy_ins);

        gemm.op.beta = 1;
        p.replace_instruction(ins, gemm, inputs);
    }
};

struct find_commutative_broadcast
{
    auto matcher() const
    {
        return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
    }

826
    void apply(module& p, const match::matcher_result& r) const
827
828
829
830
831
832
833
834
835
    {
        auto ins  = r.result;
        auto args = ins->inputs();
        move_broadcasted_back(args);

        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

836
// Runs the fusion matchers over module `p` in a fixed sequence, with a dead
// code elimination pass after the first and third matcher groups.
void fuse_ops::apply(module& p) const
{
    const auto eliminate_dead_code = [&p] { run_passes(p, {dead_code_elimination{}}); };

    match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
    eliminate_dead_code();

    match::find_matches(p, find_triadd{});

    match::find_matches(
        p,
        find_layernorm{},
        find_conv_bias_relu{ctx},
        find_conv_bias{ctx},
        find_add_gelu{},
        find_add_gelu_new{},
        find_mul_add{},
        find_mul_add_relu{},
        find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
        find_add_clip{});
    eliminate_dead_code();

    match::find_matches(
        p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
}
Paul's avatar
Paul committed
856
857

} // namespace gpu
Paul's avatar
Paul committed
858
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
859
} // namespace migraphx