fuse_ops.cpp 23.3 KB
Newer Older
kahmed10's avatar
kahmed10 committed
1
2
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
3
4
5
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
kahmed10's avatar
kahmed10 committed
6
#include <migraphx/gpu/clip.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/gpu/convolution.hpp>
8
#include <migraphx/gpu/oper.hpp>
kahmed10's avatar
kahmed10 committed
9
10
11
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/mul.hpp>
#include <migraphx/gpu/device/layernorm.hpp>
kahmed10's avatar
kahmed10 committed
12
#include <migraphx/gpu/device/gelu.hpp>
Paul's avatar
Paul committed
13
#include <migraphx/gpu/device/mul_add.hpp>
14
15
16
17
18
#include <migraphx/gpu/device/add_clip.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_sigmoid.hpp>
#include <migraphx/gpu/device/add_tanh.hpp>
#include <migraphx/gpu/device/mul_add_relu.hpp>
Paul's avatar
Paul committed
19
#include <migraphx/gpu/device/add.hpp>
Paul's avatar
Paul committed
20
#include <migraphx/instruction.hpp>
21
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
22
#include <migraphx/array.hpp>
kahmed10's avatar
kahmed10 committed
23
#include <migraphx/op/clip.hpp>
kahmed10's avatar
kahmed10 committed
24
#include <cmath>
Paul's avatar
Paul committed
25
26

namespace migraphx {
Paul's avatar
Paul committed
27
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
28
29
namespace gpu {

30
// When set, fusable_conv rejects every convolution, so no MIOpen conv+bias(+relu)
// fusion is performed.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)
// When set, find_gelu_new rewrites the tanh-based GELU pattern to hip_gelu_new
// instead of hip_gelu.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FAST_GELU)
32

Paul's avatar
Paul committed
33
34
35
36
37
38
39
40
struct fusion
{
    using op_t = miopenFusionOpDescriptor_t;
    shared<fusion_plan_descriptor> fp;

    // Used as a temporary hack to keep descriptor references alive
    std::vector<std::shared_ptr<void>> storage;

Paul's avatar
Paul committed
41
    template <class T>
Paul's avatar
Paul committed
42
43
44
45
46
47
48
    auto keep_alive(T x)
    {
        auto result = share(std::move(x));
        storage.push_back(result);
        return result;
    }

49
50
    fusion() = default;

Paul's avatar
Paul committed
51
52
    fusion(const shape& input)
    {
53
        assert(input.standard());
Paul's avatar
Paul committed
54
        auto t = make_tensor(input);
Paul's avatar
Paul committed
55
        fp     = make_fusion_plan(t);
56
        assert(fp);
Paul's avatar
Paul committed
57
58
59
60
61
        keep_alive(std::move(t));
    }

    op_t operator[](std::size_t i) const
    {
62
        assert(fp);
Paul's avatar
Paul committed
63
64
65
        op_t result;
        auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
66
            MIGRAPHX_THROW("Failed retrieving operator at " + std::to_string(i));
Paul's avatar
Paul committed
67
68
69
        return result;
    }

70
71
72
73
74
    auto get() const
    {
        assert(fp);
        return fp.get();
    }
Paul's avatar
Paul committed
75
76
77

    op_t create_bias(const shape& bias)
    {
78
        assert(fp);
Paul's avatar
Paul committed
79
        op_t result;
Paul's avatar
Paul committed
80
81
        auto b      = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
        auto t      = keep_alive(make_tensor(b));
Paul's avatar
Paul committed
82
83
        auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
84
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
85
86
87
88
89
        return result;
    }

    op_t create_relu()
    {
90
        assert(fp);
Paul's avatar
Paul committed
91
92
93
        op_t result;
        auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
94
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
95
96
97
98
99
        return result;
    }

    op_t create_conv(const op::convolution& op, const shape& weights)
    {
100
        assert(fp);
Paul's avatar
Paul committed
101
        op_t result;
Paul's avatar
Paul committed
102
103
        auto cd     = keep_alive(make_conv(op));
        auto t      = keep_alive(make_tensor(weights));
Paul's avatar
Paul committed
104
105
        auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
106
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
107
108
        return result;
    }
Paul's avatar
Paul committed
109
110
111

    shape get_workspace(context&)
    {
112
        // assert(fp);
Paul's avatar
Paul committed
113
114
115
116
117
        // TODO: Use zero workspace for now
        std::size_t ws_size = 0;
        // int algo_count = 1;
        // miopenConvFwdAlgorithm_t algo;
        // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
Paul's avatar
Paul committed
118
119
        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
        // algo);
Paul's avatar
Paul committed
120
121
122
123
124
        return shape{shape::int8_type, {ws_size}};
    }

    void compile(context& ctx)
    {
125
        assert(fp);
Paul's avatar
Paul committed
126
        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
Paul's avatar
Paul committed
127
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
128
            MIGRAPHX_THROW("Compiling fusion plan failed");
Paul's avatar
Paul committed
129
130
    }

Paul's avatar
Paul committed
131
132
133
134
    argument execute(context& ctx,
                     const fused_operator_args& fargs,
                     const argument& x,
                     const argument& y) const
Paul's avatar
Paul committed
135
    {
136
        assert(fp);
Paul's avatar
Paul committed
137
138
        auto x_td   = make_tensor(x.get_shape());
        auto y_td   = make_tensor(y.get_shape());
Paul's avatar
Paul committed
139
        auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(),
Paul's avatar
Paul committed
140
141
142
143
144
145
                                              fp.get(),
                                              x_td.get(),
                                              x.implicit(),
                                              y_td.get(),
                                              y.implicit(),
                                              fargs.get());
Paul's avatar
Paul committed
146
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
147
            MIGRAPHX_THROW("Failed to execute fusion plan");
Paul's avatar
Paul committed
148
149
        return y;
    }
Paul's avatar
Paul committed
150
151
};

Paul's avatar
Paul committed
152
// Matches a broadcasted 4-D shape whose strides are {0, nonzero, 0, 0}, i.e. a
// value that varies only along axis 1 (such as a per-channel bias broadcast).
MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    if(not s.broadcasted())
        return false;
    const auto& strides = s.strides();
    if(strides.size() != 4)
        return false;
    return strides[0] == 0 and strides[1] != 0 and strides[2] == 0 and strides[3] == 0;
}

Paul's avatar
Paul committed
159
// Decides whether a gpu::convolution may be fused through MIOpen's fusion API.
// Rejects: everything when MIGRAPHX_DISABLE_MIOPEN_FUSION is set, non-float
// outputs, non-4-D inputs/weights, grouped convolutions, more than 512 input
// channels unless the algorithm is Winograd, non-square inputs or kernels,
// Winograd on non-3x3 kernels with stride 1, and padding/stride/dilation outside
// the supported sets.
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
    if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
        return false;
    if(ins->name() != "gpu::convolution")
        return false;
    if(ins->get_shape().type() != shape::float_type)
        return false;

    auto wei_lens   = ins->inputs().at(1)->get_shape().lens();
    auto input_lens = ins->inputs().at(0)->get_shape().lens();
    // Only 4-D tensors are handled. This used to be just an assert, so release
    // builds would index lens()[2]/[3] out of bounds (or misread 1-D/3-D
    // convolutions) below. Reject such convolutions instead.
    if(wei_lens.size() != 4 or input_lens.size() != 4)
        return false;

    auto conv = any_cast<miopen_convolution>(ins->get_operator());
    if(conv.op.group > 1)
        return false;
    if(wei_lens[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
        return false;

    // Do not fuse non-symmetric input
    if(input_lens[2] != input_lens[3] or wei_lens[2] != wei_lens[3])
        return false;

    const auto& op = conv.op;
    // Don't fuse winograd for non-3x3s since there is no fused winograd for those configs
    if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei_lens[2] != 3 and
       wei_lens[3] != 3 and contains({{1, 1}}, op.stride))
        return false;
    return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
           contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation);
}

189
struct hip_triadd : ternary_device<hip_triadd, &device::add>
Paul's avatar
Paul committed
190
191
{
};
192
MIGRAPHX_REGISTER_OP(hip_triadd)
Paul's avatar
Paul committed
193

194
struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
195
196
{
};
197
MIGRAPHX_REGISTER_OP(hip_triadd_clip)
kahmed10's avatar
kahmed10 committed
198

199
struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
200
201
{
};
202
MIGRAPHX_REGISTER_OP(hip_add_clip)
kahmed10's avatar
kahmed10 committed
203

204
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
Paul's avatar
Paul committed
205
206
{
};
207
MIGRAPHX_REGISTER_OP(hip_triadd_relu)
Paul's avatar
Paul committed
208

209
210
211
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
{
};
212
MIGRAPHX_REGISTER_OP(hip_triadd_sigmoid)
213
214
215
216

struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
{
};
217
MIGRAPHX_REGISTER_OP(hip_triadd_tanh)
218
219
220
221

struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
{
};
222
MIGRAPHX_REGISTER_OP(hip_add_relu)
223
224
225
226

struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
{
};
227
MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
228
229

struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
Paul's avatar
Paul committed
230
231
{
};
232
MIGRAPHX_REGISTER_OP(hip_add_tanh)
Paul's avatar
Paul committed
233

kahmed10's avatar
kahmed10 committed
234
235
236
struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
{
};
237
MIGRAPHX_REGISTER_OP(hip_layernorm)
kahmed10's avatar
kahmed10 committed
238

kahmed10's avatar
kahmed10 committed
239
240
241
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
242
MIGRAPHX_REGISTER_OP(hip_gelu)
kahmed10's avatar
kahmed10 committed
243
244
245
246

struct hip_add_gelu : binary_device<hip_add_gelu, &device::add_gelu>
{
};
247
MIGRAPHX_REGISTER_OP(hip_add_gelu)
kahmed10's avatar
kahmed10 committed
248
249
250
251

struct hip_gelu_new : unary_device<hip_gelu_new, &device::gelu_new>
{
};
252
MIGRAPHX_REGISTER_OP(hip_gelu_new)
kahmed10's avatar
kahmed10 committed
253
254
255
256

struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
257
MIGRAPHX_REGISTER_OP(hip_add_gelu_new)
kahmed10's avatar
kahmed10 committed
258

259
struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
Paul's avatar
Paul committed
260
261
{
};
262
MIGRAPHX_REGISTER_OP(hip_mul_add)
Paul's avatar
Paul committed
263

264
struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
Paul's avatar
Paul committed
265
266
{
};
267
MIGRAPHX_REGISTER_OP(hip_mul_add_relu)
Paul's avatar
Paul committed
268

Paul's avatar
Paul committed
269
270
271
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last arguments is the broadcasted one
Paul's avatar
Paul committed
272
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
273
274
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
Paul's avatar
Paul committed
275
276
    if(it != last)
        std::swap(*it, *std::prev(last));
Paul's avatar
Paul committed
277
278
279
280
281
}

void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first arguments is the standard one
Paul's avatar
Paul committed
282
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
283
284
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
Paul's avatar
Paul committed
285
    if(it != last)
Paul's avatar
Paul committed
286
287
288
        std::swap(*it, args.front());
}

kahmed10's avatar
kahmed10 committed
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
// Rewrites the subgraph
//   (x - mean(x)) / sqrt(mean((x - mean(x))^2) + 1e-12)
// (with multibroadcasts in between) into a single hip_layernorm.
// NOTE(review): the "x" bound in the numerator and inside the variance are
// not checked to be the same instruction.
struct find_layernorm
{
    // A multibroadcast whose input satisfies all of `xs`.
    template <class... Ts>
    static auto multibroadcast_op(Ts... xs)
    {
        auto input = match::arg(0)(xs...);
        return match::name("multibroadcast")(input);
    }

    // gpu::sub(x, multibroadcast(reduce_mean(...))), binding the minuend as "x".
    static auto x_minus_mean()
    {
        auto mean = multibroadcast_op(match::name("gpu::reduce_mean"));
        return match::name("gpu::sub")(match::arg(0)(match::any().bind("x")),
                                       match::arg(1)(mean));
    }

    // reduce_mean(pow(x - mean, 2))
    static auto variance()
    {
        auto two    = multibroadcast_op(match::has_value(2.0f));
        auto square = match::name("gpu::pow")(match::arg(0)(x_minus_mean()), match::arg(1)(two));
        return match::name("gpu::reduce_mean")(match::arg(0)(square));
    }

    // (x - mean) / multibroadcast(sqrt(variance + 1e-12))
    static auto layernorm_onnx()
    {
        auto epsilon  = multibroadcast_op(match::has_value(1e-12f));
        auto var_eps  = match::name("gpu::add")(match::either_arg(0, 1)(variance(), epsilon));
        auto std_dev  = match::name("gpu::sqrt")(match::arg(0)(var_eps));
        auto divisor  = multibroadcast_op(std_dev);
        return match::name("gpu::div")(match::arg(0)(x_minus_mean()), match::arg(1)(divisor));
    }

    auto matcher() const { return layernorm_onnx(); }

    void apply(program& p, match::matcher_result r) const
    {
        auto div_ins = r.result;
        auto x_ins   = r.instructions["x"];
        // Reuse the div's last input (its output allocation) for the fused op.
        p.replace_instruction(div_ins, hip_layernorm{}, x_ins, div_ins->inputs().back());
    }
};

kahmed10's avatar
kahmed10 committed
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
// Rewrites mul(mul(_, 0.5), add(erf(x * M_SQRT1_2), 1)) into hip_gelu(x).
// NOTE(review): the operand multiplied by 0.5 is not checked to be x.
struct find_gelu
{
    // erf(x * M_SQRT1_2) where both the erf and the inner mul are used once;
    // binds the non-constant multiplicand as "x".
    static auto erf_fn()
    {
        auto inv_sqrt2 = match::has_value(M_SQRT1_2);
        auto scaled_x  = match::name("gpu::mul")(match::either_arg(0, 1)(
            match::none_of(match::has_value(M_SQRT1_2)).bind("x"), inv_sqrt2));
        return match::name("gpu::erf")(match::used_once(),
                                       match::arg(0)(match::used_once(), scaled_x));
    }

    auto matcher() const
    {
        auto half = match::name("gpu::mul")(
            match::any_arg(0, 1)(match::args(match::has_value(0.5f))));
        auto one_plus_erf = match::name("gpu::add")(
            match::used_once(),
            match::either_arg(0, 1)(erf_fn(), match::args(match::has_value(1.0f))));
        return match::name("gpu::mul")(match::either_arg(0, 1)(half, one_plus_erf));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto mul_ins = r.result;
        auto x_ins   = r.instructions["x"];
        // Reuse the matched mul's last input (its output allocation).
        p.replace_instruction(mul_ins, hip_gelu{}, x_ins, mul_ins->inputs().back());
    }
};

// Folds gpu::gelu(gpu::add(...)) into a single hip_add_gelu.
struct find_add_gelu
{
    auto matcher() const
    {
        auto add = match::name("gpu::add").bind("add");
        return match::name("gpu::gelu")(match::arg(0)(add));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto gelu_ins   = r.result;
        auto add_ins    = r.instructions["add"];
        auto fused_args = add_ins->inputs();
        move_standard_front(fused_args);
        move_broadcasted_back(fused_args);

        // Write into the gelu's output allocation instead of the add's.
        fused_args.back() = gelu_ins->inputs().back();
        p.replace_instruction(gelu_ins, hip_add_gelu{}, fused_args);
    }
};

// Rewrites the tanh-based GELU pattern
//   mul(x, add(mul(0.5, tanh(mul(sqrt(2/pi), add(mul(0.044715, pow(_, 3)), _)))), _))
// into hip_gelu (default) or hip_gelu_new (MIGRAPHX_DISABLE_FAST_GELU set).
struct find_gelu_new
{
    // pow(_, 3), used once.
    static auto pow_fn()
    {
        auto cube = match::args(match::has_value(3.0f));
        return match::name("gpu::pow")(match::used_once(), match::arg(1)(cube));
    }

    // tanh(sqrt(2/pi) * (... + 0.044715 * pow(_, 3))), used once.
    static auto tanh_fn()
    {
        auto cubic_term = match::name("gpu::mul")(
            match::either_arg(0, 1)(match::args(match::has_value(0.044715f)), pow_fn()));
        auto inner_sum  = match::name("gpu::add")(match::any_arg(0, 1)(cubic_term));
        auto scaled     = match::name("gpu::mul")(
            match::either_arg(0, 1)(match::args(match::has_value(sqrt(M_2_PI))), inner_sum));
        return match::name("gpu::tanh")(match::used_once(), match::arg(0)(scaled));
    }

    auto matcher() const
    {
        auto half_tanh = match::name("gpu::mul")(
            match::either_arg(0, 1)(match::args(match::has_value(0.5f)), tanh_fn()));
        auto outer_sum = match::name("gpu::add")(match::any_arg(0, 1)(half_tanh));
        return match::name("gpu::mul")(match::used_once(),
                                       match::either_arg(0, 1)(match::any().bind("x"), outer_sum));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto mul_ins = r.result;
        auto x_ins   = r.instructions["x"];
        auto output  = mul_ins->inputs().back();

        if(not enabled(MIGRAPHX_DISABLE_FAST_GELU{}))
            p.replace_instruction(mul_ins, hip_gelu{}, x_ins, output);
        else
            p.replace_instruction(mul_ins, hip_gelu_new{}, x_ins, output);
    }
};

// Folds gpu::gelu_new(gpu::add(...)) into a single hip_add_gelu_new.
struct find_add_gelu_new
{
    auto matcher() const
    {
        auto add = match::name("gpu::add").bind("add");
        return match::name("gpu::gelu_new")(match::arg(0)(add));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto gelu_ins   = r.result;
        auto add_ins    = r.instructions["add"];
        auto fused_args = add_ins->inputs();
        move_standard_front(fused_args);
        move_broadcasted_back(fused_args);

        // Write into the gelu_new's output allocation instead of the add's.
        fused_args.back() = gelu_ins->inputs().back();
        p.replace_instruction(gelu_ins, hip_add_gelu_new{}, fused_args);
    }
};

kahmed10's avatar
kahmed10 committed
448
449
450
451
452
453
struct find_add_clip
{
    auto matcher() const
    {
        return match::name(std::unordered_set<std::string>{"gpu::clip", "gpu::clipped_relu"})(
            match::arg(0)(match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
454
                                        match::name("gpu::triadd"),
kahmed10's avatar
kahmed10 committed
455
456
457
458
459
460
                                        match::any_of[match::inputs()](match::standard_shape()))
                              .bind("add")));
    }

    void apply(program& p, match::matcher_result r) const
    {
kahmed10's avatar
kahmed10 committed
461
462
463
464
465
466
467
468
469
470
        auto add_ins  = r.instructions["add"];
        auto ins      = r.result;
        auto ins_args = ins->inputs();
        auto add_args = add_ins->inputs();
        move_standard_front(add_args);
        move_broadcasted_back(add_args);

        // Use the allocation from the clip operator
        add_args.pop_back();
        add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
kahmed10's avatar
kahmed10 committed
471
        if(add_ins->name() == "gpu::add")
kahmed10's avatar
kahmed10 committed
472
            p.replace_instruction(ins, hip_add_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
473
        else if(add_ins->name() == "gpu::triadd")
kahmed10's avatar
kahmed10 committed
474
            p.replace_instruction(ins, hip_triadd_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
475
476
477
    }
};

478
struct find_add_unary
Paul's avatar
Paul committed
479
{
480
481
482
    std::string op_name;
    operation binary_add_op;
    operation ternary_add_op;
Paul's avatar
Paul committed
483
484
    auto matcher() const
    {
485
        return match::name(op_name)(match::arg(0)(
Paul's avatar
Paul committed
486
            match::used_once(),
Paul's avatar
Paul committed
487
            match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
488
                          match::name("gpu::triadd"),
Paul's avatar
Paul committed
489
490
491
                          match::any_of(match::name("@literal"),
                                        match::any_of[match::inputs()](match::standard_shape())))
                .bind("add")));
Paul's avatar
Paul committed
492
    }
Paul's avatar
Paul committed
493

Paul's avatar
Paul committed
494
495
    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
496
        auto add_ins = r.instructions["add"];
Paul's avatar
Paul committed
497
498
        auto ins     = r.result;
        auto args    = add_ins->inputs();
Paul's avatar
Paul committed
499
500
501
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
502
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
503
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
504
        if(add_ins->name() == "gpu::add")
505
            p.replace_instruction(ins, binary_add_op, args);
kahmed10's avatar
kahmed10 committed
506
        else if(add_ins->name() == "gpu::triadd")
507
            p.replace_instruction(ins, ternary_add_op, args);
Paul's avatar
Paul committed
508
509
510
    }
};

Paul's avatar
Paul committed
511
struct find_triadd
Paul's avatar
Paul committed
512
513
514
{
    auto matcher() const
    {
Paul's avatar
Paul committed
515
        return match::name("gpu::add")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
516
            match::name("gpu::add")(match::used_once()).bind("add"),
Paul's avatar
Paul committed
517
518
519
            match::any(match::any_of(match::name("@literal"),
                                     match::any_of[match::inputs()](match::standard_shape())))
                .bind("input")));
Paul's avatar
Paul committed
520
521
522
523
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
524
525
526
527
        auto add_ins   = r.instructions["add"];
        auto input_ins = r.instructions["input"];
        auto ins       = r.result;
        auto args      = add_ins->inputs();
528
529
        assert(add_ins != input_ins);

Paul's avatar
Paul committed
530
531
532
533
        auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
            return;
        args.insert(args.begin(), input_ins);
Paul's avatar
Paul committed
534
535
536
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
537
538
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_triadd{}, args);
Paul's avatar
Paul committed
539
    }
Paul's avatar
Paul committed
540
541
};

Paul's avatar
Paul committed
542
543
544
545
// Merges add(mul(a, b), c) into hip_mul_add(a, b, c) when the mul is used once.
struct find_mul_add
{
    auto matcher() const
    {
        auto mul    = match::name("gpu::mul")(match::used_once()).bind("mul");
        auto addend = match::any().bind("b");
        return match::name("gpu::add")(match::either_arg(0, 1)(mul, addend));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto add_ins = r.result;
        auto mul_ins = r.instructions["mul"];
        auto b_ins   = r.instructions["b"];
        assert(mul_ins != b_ins);

        auto fused_args = mul_ins->inputs();
        move_standard_front(fused_args);
        move_broadcasted_back(fused_args);
        // Place the addend just before the trailing allocation slot.
        fused_args.insert(std::prev(fused_args.end()), b_ins);

        // Write into the add's output allocation.
        fused_args.back() = add_ins->inputs().back();
        p.replace_instruction(add_ins, hip_mul_add{}, fused_args);
    }
};

Paul's avatar
Paul committed
567
568
569
570
// Folds gpu::relu(gpu::mul_add(...)) into hip_mul_add_relu when the mul_add is
// used once.
struct find_mul_add_relu
{
    auto matcher() const
    {
        auto mul_add = match::name("gpu::mul_add")(match::used_once()).bind("mul_add");
        return match::name("gpu::relu")(match::arg(0)(mul_add));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto relu_ins   = r.result;
        auto fused_args = r.instructions["mul_add"]->inputs();

        // Use the allocation from the relu operator
        fused_args.back() = relu_ins->inputs().back();
        p.replace_instruction(relu_ins, hip_mul_add_relu{}, fused_args);
    }
};

Paul's avatar
Paul committed
587
588
589
struct miopen_conv_bias
{
    op::convolution op;
590
591
592
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
Paul's avatar
Paul committed
593

Paul's avatar
Paul committed
594
595
596
597
598
599
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Paul committed
600
601
602
603
604
605
606
    std::string name() const { return "gpu::conv_bias"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
607
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
608
    {
Paul's avatar
Paul committed
609
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
610
        float alpha = 1;
Paul's avatar
Paul committed
611
        float beta  = 0;
Paul's avatar
Paul committed
612
613
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
614
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Paul committed
615
616
    }

617
618
619
620
621
622
623
624
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        f.compile(ctx);
    }

Paul's avatar
Paul committed
625
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
626
627
628
629
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
630
};
631
MIGRAPHX_REGISTER_OP(miopen_conv_bias)
Paul's avatar
Paul committed
632

Paul's avatar
Add cbr  
Paul committed
633
634
635
struct miopen_conv_bias_relu
{
    op::convolution op;
636
637
638
639
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
    fusion::op_t relu = {};
Paul's avatar
Add cbr  
Paul committed
640

Paul's avatar
Paul committed
641
642
643
644
645
646
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Add cbr  
Paul committed
647
648
649
650
651
652
653
    std::string name() const { return "gpu::conv_bias_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
654
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Add cbr  
Paul committed
655
656
    {
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
657
        float alpha = 1;
Paul's avatar
Paul committed
658
        float beta  = 0;
Paul's avatar
Add cbr  
Paul committed
659
660
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
661
662
        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Add cbr  
Paul committed
663
    }
664
665
666
667
668
669
670
671
672
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        relu = f.create_relu();
        f.compile(ctx);
    }

Paul's avatar
Paul committed
673
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
674
675
676
677
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Add cbr  
Paul committed
678
};
679
MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu)
Paul's avatar
Add cbr  
Paul committed
680

Paul's avatar
Paul committed
681
template <class... Ms>
Paul's avatar
Add cbr  
Paul committed
682
683
auto conv_bias(Ms... ms)
{
Paul's avatar
Paul committed
684
    return match::name("gpu::add")(
Paul's avatar
Paul committed
685
686
        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
                                fusable_conv(match::used_once()).bind("conv")),
Paul's avatar
Paul committed
687
        ms...);
Paul's avatar
Paul committed
688
689
}

Paul's avatar
Paul committed
690
template <class Op>
Paul's avatar
Paul committed
691
692
693
694
695
696
697
698
699
700
701
void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
{
    auto conv_ins    = r.instructions["conv"];
    auto bias_ins    = r.instructions["bias"];
    auto ins         = r.result;
    auto input_ins   = conv_ins->inputs().at(0);
    auto weights_ins = conv_ins->inputs().at(1);
    auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
    auto alloc_ins   = ins->inputs().back();
    auto old_ws_ins  = conv_ins->inputs().at(2);

702
    Op cb{conv_op};
Paul's avatar
Paul committed
703
    // TODO: Insert ws allocation
Paul's avatar
Paul committed
704
    auto ws = cb.get_workspace(ctx);
Paul's avatar
Paul committed
705
    (void)ws;
Paul's avatar
Paul committed
706
    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
Paul's avatar
Add cbr  
Paul committed
707
708
}

Paul's avatar
Paul committed
709
struct find_conv_bias
Paul's avatar
Paul committed
710
{
Paul's avatar
Paul committed
711
    context* ctx = nullptr;
Paul's avatar
Paul committed
712
713
    auto matcher() const
    {
kahmed10's avatar
kahmed10 committed
714
715
        return conv_bias(match::none_of(
            match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
Paul's avatar
Paul committed
716
717
718
719
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
720
        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
Paul's avatar
Paul committed
721
722
723
    }
};

Paul's avatar
Paul committed
724
struct find_conv_bias_relu
Paul's avatar
Add cbr  
Paul committed
725
726
{
    context* ctx = nullptr;
Paul's avatar
Paul committed
727
    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
Paul's avatar
Add cbr  
Paul committed
728
729
730

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
731
        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
Paul's avatar
Add cbr  
Paul committed
732
733
734
    }
};

Paul's avatar
Paul committed
735
736
// Runs the GPU fusion matchers over `p` in three stages.
void fuse_ops::apply(program& p) const
{
    // Stage 1: collapse GELU subgraphs, then remove the instructions they orphaned.
    match::find_matches(p, find_gelu{}, find_gelu_new{});
    run_passes(p, {dead_code_elimination{}});

    // Stage 2: merge nested adds into triadds so later stages can fuse them.
    match::find_matches(p, find_triadd{});

    // Stage 3: all remaining fusions, in a fixed order
    // (e.g. find_conv_bias_relu is listed before find_conv_bias).
    match::find_matches(p,
                        find_layernorm{},
                        find_conv_bias_relu{ctx},
                        find_conv_bias{ctx},
                        find_add_gelu{},
                        find_add_gelu_new{},
                        find_mul_add{},
                        find_mul_add_relu{},
                        find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
                        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                        find_add_clip{});
}
Paul's avatar
Paul committed
754
755

} // namespace gpu
Paul's avatar
Paul committed
756
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
757
} // namespace migraphx