fuse_ops.cpp 23.8 KB
Newer Older
kahmed10's avatar
kahmed10 committed
1
2
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
Paul's avatar
Paul committed
3
4
5
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
kahmed10's avatar
kahmed10 committed
6
#include <migraphx/gpu/clip.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/gpu/convolution.hpp>
8
#include <migraphx/gpu/oper.hpp>
kahmed10's avatar
kahmed10 committed
9
10
#include <migraphx/gpu/add.hpp>
#include <migraphx/gpu/mul.hpp>
11
#include <migraphx/gpu/gemm.hpp>
kahmed10's avatar
kahmed10 committed
12
#include <migraphx/gpu/device/layernorm.hpp>
kahmed10's avatar
kahmed10 committed
13
#include <migraphx/gpu/device/gelu.hpp>
Paul's avatar
Paul committed
14
#include <migraphx/gpu/device/mul_add.hpp>
15
16
17
18
19
#include <migraphx/gpu/device/add_clip.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_sigmoid.hpp>
#include <migraphx/gpu/device/add_tanh.hpp>
#include <migraphx/gpu/device/mul_add_relu.hpp>
Paul's avatar
Paul committed
20
#include <migraphx/gpu/device/add.hpp>
21
22
23
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/match/gelu_tanh.hpp>
Paul's avatar
Paul committed
24
#include <migraphx/instruction.hpp>
25
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
26
#include <migraphx/array.hpp>
kahmed10's avatar
kahmed10 committed
27
#include <migraphx/op/clip.hpp>
kahmed10's avatar
kahmed10 committed
28
#include <cmath>
Paul's avatar
Paul committed
29
30

namespace migraphx {
Paul's avatar
Paul committed
31
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
32
33
namespace gpu {

34
35
// Environment switch consulted by the fusable_conv matcher: when it is enabled
// that matcher rejects every instruction.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)

Paul's avatar
Paul committed
36
37
38
39
40
41
42
43
struct fusion
{
    using op_t = miopenFusionOpDescriptor_t;
    shared<fusion_plan_descriptor> fp;

    // Used as a temporary hack to keep descriptor references alive
    std::vector<std::shared_ptr<void>> storage;

Paul's avatar
Paul committed
44
    template <class T>
Paul's avatar
Paul committed
45
46
47
48
49
50
51
    auto keep_alive(T x)
    {
        auto result = share(std::move(x));
        storage.push_back(result);
        return result;
    }

52
53
    fusion() = default;

Paul's avatar
Paul committed
54
55
    fusion(const shape& input)
    {
56
        assert(input.standard());
Paul's avatar
Paul committed
57
        auto t = make_tensor(input);
Paul's avatar
Paul committed
58
        fp     = make_fusion_plan(t);
59
        assert(fp);
Paul's avatar
Paul committed
60
61
62
63
64
        keep_alive(std::move(t));
    }

    op_t operator[](std::size_t i) const
    {
65
        assert(fp);
Paul's avatar
Paul committed
66
67
68
        op_t result;
        auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
69
            MIGRAPHX_THROW("Failed retrieving operator at " + std::to_string(i));
Paul's avatar
Paul committed
70
71
72
        return result;
    }

73
74
75
76
77
    auto get() const
    {
        assert(fp);
        return fp.get();
    }
Paul's avatar
Paul committed
78
79
80

    op_t create_bias(const shape& bias)
    {
81
        assert(fp);
Paul's avatar
Paul committed
82
        op_t result;
Paul's avatar
Paul committed
83
84
        auto b      = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
        auto t      = keep_alive(make_tensor(b));
Paul's avatar
Paul committed
85
86
        auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
87
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
88
89
90
91
92
        return result;
    }

    op_t create_relu()
    {
93
        assert(fp);
Paul's avatar
Paul committed
94
95
96
        op_t result;
        auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
97
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
98
99
100
101
102
        return result;
    }

    op_t create_conv(const op::convolution& op, const shape& weights)
    {
103
        assert(fp);
Paul's avatar
Paul committed
104
        op_t result;
Paul's avatar
Paul committed
105
106
        auto cd     = keep_alive(make_conv(op));
        auto t      = keep_alive(make_tensor(weights));
Paul's avatar
Paul committed
107
108
        auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
109
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
110
111
        return result;
    }
Paul's avatar
Paul committed
112
113
114

    shape get_workspace(context&)
    {
115
        // assert(fp);
Paul's avatar
Paul committed
116
117
118
119
120
        // TODO: Use zero workspace for now
        std::size_t ws_size = 0;
        // int algo_count = 1;
        // miopenConvFwdAlgorithm_t algo;
        // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
Paul's avatar
Paul committed
121
122
        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
        // algo);
Paul's avatar
Paul committed
123
124
125
126
127
        return shape{shape::int8_type, {ws_size}};
    }

    void compile(context& ctx)
    {
128
        assert(fp);
Paul's avatar
Paul committed
129
        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
Paul's avatar
Paul committed
130
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
131
            MIGRAPHX_THROW("Compiling fusion plan failed");
Paul's avatar
Paul committed
132
133
    }

Paul's avatar
Paul committed
134
135
136
137
    argument execute(context& ctx,
                     const fused_operator_args& fargs,
                     const argument& x,
                     const argument& y) const
Paul's avatar
Paul committed
138
    {
139
        assert(fp);
Paul's avatar
Paul committed
140
141
        auto x_td   = make_tensor(x.get_shape());
        auto y_td   = make_tensor(y.get_shape());
Paul's avatar
Paul committed
142
        auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(),
Paul's avatar
Paul committed
143
144
145
146
147
148
                                              fp.get(),
                                              x_td.get(),
                                              x.implicit(),
                                              y_td.get(),
                                              y.implicit(),
                                              fargs.get());
Paul's avatar
Paul committed
149
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
150
            MIGRAPHX_THROW("Failed to execute fusion plan");
Paul's avatar
Paul committed
151
152
        return y;
    }
Paul's avatar
Paul committed
153
154
};

Paul's avatar
Paul committed
155
// Predicate matcher: true for a broadcasted 4-d shape whose only non-zero
// stride is the one at index 1.
MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    if(not s.broadcasted())
        return false;

    const auto& strides = s.strides();
    if(strides.size() != 4)
        return false;

    return strides[0] == 0 and strides[1] != 0 and strides[2] == 0 and strides[3] == 0;
}

Paul's avatar
Paul committed
162
// Predicate matcher: true for a float gpu::convolution whose configuration
// passes every fusion restriction checked below.
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
    // Fusion switched off, wrong operator, or non-float output: nothing to fuse
    if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}) or ins->name() != "gpu::convolution" or
       ins->get_shape().type() != shape::float_type)
        return false;

    const auto weights = ins->inputs().at(1)->get_shape();
    assert(weights.lens().size() == 4);
    const auto& w_lens = weights.lens();

    const auto conv = any_cast<miopen_convolution>(ins->get_operator());
    if(conv.op.group > 1)
        return false;

    const bool is_winograd = conv.algo == miopenConvolutionFwdAlgoWinograd;
    if(w_lens[1] > 512 and not is_winograd)
        return false;

    // Do not fuse non-symmetric input
    const auto in_lens = ins->inputs().at(0)->get_shape().lens();
    if(in_lens[2] != in_lens[3] or w_lens[2] != w_lens[3])
        return false;

    const auto& op = conv.op;
    // Don't fuse winograd for non-3x3s since there is no fused winograd for those configs
    if(is_winograd and w_lens[2] != 3 and w_lens[3] != 3 and contains({{1, 1}}, op.stride))
        return false;

    if(not contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding))
        return false;
    if(not contains({{0, 0}, {1, 1}}, op.stride))
        return false;
    return contains({{1, 1}}, op.dilation);
}

192
struct hip_triadd : ternary_device<hip_triadd, &device::add>
Paul's avatar
Paul committed
193
194
{
};
195
MIGRAPHX_REGISTER_OP(hip_triadd)
Paul's avatar
Paul committed
196

197
struct hip_triadd_clip : quinary_device<hip_triadd_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
198
199
{
};
200
MIGRAPHX_REGISTER_OP(hip_triadd_clip)
kahmed10's avatar
kahmed10 committed
201

202
struct hip_add_clip : quaternary_device<hip_add_clip, &device::add_clip>
kahmed10's avatar
kahmed10 committed
203
204
{
};
205
MIGRAPHX_REGISTER_OP(hip_add_clip)
kahmed10's avatar
kahmed10 committed
206

207
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu>
Paul's avatar
Paul committed
208
209
{
};
210
MIGRAPHX_REGISTER_OP(hip_triadd_relu)
Paul's avatar
Paul committed
211

212
213
214
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid>
{
};
215
MIGRAPHX_REGISTER_OP(hip_triadd_sigmoid)
216
217
218
219

struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh>
{
};
220
MIGRAPHX_REGISTER_OP(hip_triadd_tanh)
221
222
223
224

struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
{
};
225
MIGRAPHX_REGISTER_OP(hip_add_relu)
226
227
228
229

struct hip_add_sigmoid : binary_device<hip_add_relu, &device::add_sigmoid>
{
};
230
MIGRAPHX_REGISTER_OP(hip_add_sigmoid)
231
232

struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
Paul's avatar
Paul committed
233
234
{
};
235
MIGRAPHX_REGISTER_OP(hip_add_tanh)
Paul's avatar
Paul committed
236

kahmed10's avatar
kahmed10 committed
237
238
struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
{
239
240
    // Empty finalize to skip dimension reduction
    void finalize(context&, const shape&, const std::vector<shape>&) {}
kahmed10's avatar
kahmed10 committed
241
};
242
MIGRAPHX_REGISTER_OP(hip_layernorm)
kahmed10's avatar
kahmed10 committed
243

Paul Fultz II's avatar
Paul Fultz II committed
244
245
246
247
248
249
250
// Device op bound to device::triadd_layernorm through the ternary_device CRTP
// base. find_triadd_layernorm emits it for a gpu::layernorm fed by a gpu::triadd.
struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
{
    // Empty finalize to skip dimension reduction
    void finalize(context&, const shape&, const std::vector<shape>&) {}
};
MIGRAPHX_REGISTER_OP(hip_triadd_layernorm)

kahmed10's avatar
kahmed10 committed
251
252
253
struct hip_gelu : unary_device<hip_gelu, &device::gelu>
{
};
254
MIGRAPHX_REGISTER_OP(hip_gelu)
kahmed10's avatar
kahmed10 committed
255
256
257
258

struct hip_add_gelu : binary_device<hip_add_gelu, &device::add_gelu>
{
};
259
MIGRAPHX_REGISTER_OP(hip_add_gelu)
kahmed10's avatar
kahmed10 committed
260
261
262
263

struct hip_gelu_new : unary_device<hip_gelu_new, &device::gelu_new>
{
};
264
MIGRAPHX_REGISTER_OP(hip_gelu_new)
kahmed10's avatar
kahmed10 committed
265
266
267
268

struct hip_add_gelu_new : binary_device<hip_add_gelu_new, &device::add_gelu_new>
{
};
269
MIGRAPHX_REGISTER_OP(hip_add_gelu_new)
kahmed10's avatar
kahmed10 committed
270

271
struct hip_mul_add : ternary_device<hip_mul_add, &device::mul_add>
Paul's avatar
Paul committed
272
273
{
};
274
MIGRAPHX_REGISTER_OP(hip_mul_add)
Paul's avatar
Paul committed
275

276
struct hip_mul_add_relu : ternary_device<hip_mul_add_relu, &device::mul_add_relu>
Paul's avatar
Paul committed
277
278
{
};
279
MIGRAPHX_REGISTER_OP(hip_mul_add_relu)
Paul's avatar
Paul committed
280

Paul's avatar
Paul committed
281
282
283
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last arguments is the broadcasted one
Paul's avatar
Paul committed
284
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
285
286
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
Paul's avatar
Paul committed
287
288
    if(it != last)
        std::swap(*it, *std::prev(last));
Paul's avatar
Paul committed
289
290
291
292
293
}

void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first arguments is the standard one
Paul's avatar
Paul committed
294
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
295
296
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
Paul's avatar
Paul committed
297
    if(it != last)
Paul's avatar
Paul committed
298
299
300
        std::swap(*it, args.front());
}

301
302
// Builds a name matcher for s prefixed with "gpu::".
auto gpu_name(const std::string& s)
{
    return match::name("gpu::" + s);
}

kahmed10's avatar
kahmed10 committed
303
304
struct find_layernorm
{
305
    auto matcher() const { return match::layernorm(&gpu_name); }
kahmed10's avatar
kahmed10 committed
306

307
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
308
309
310
311
312
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

313
314
315
316
317
318
319
320
321
        // We dont fuse for non-standard layouts
        if(not x_ins->get_shape().standard())
            return;

        auto relements = x_ins->get_shape().lens().back();

        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
            return;

kahmed10's avatar
kahmed10 committed
322
323
324
325
        p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
    }
};

Paul Fultz II's avatar
Paul Fultz II committed
326
327
328
329
330
331
332
333
struct find_triadd_layernorm
{
    auto matcher() const
    {
        return match::name("gpu::layernorm")(match::arg(0)(match::name("gpu::triadd")(
            match::used_once(), match::all_of[match::inputs()](match::standard_shape()))));
    }

Shucai Xiao's avatar
Shucai Xiao committed
334
    void apply(module& p, const match::matcher_result& r) const
Paul Fultz II's avatar
Paul Fultz II committed
335
336
337
338
339
340
341
    {
        auto ins    = r.result;
        auto triadd = ins->inputs().front();
        p.replace_instruction(ins, hip_triadd_layernorm{}, triadd->inputs());
    }
};

kahmed10's avatar
kahmed10 committed
342
343
struct find_gelu
{
344
    auto matcher() const { return match::gelu_erf(&gpu_name); }
kahmed10's avatar
kahmed10 committed
345

346
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

        p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
    }
};

struct find_add_gelu
{
    auto matcher() const
    {
        return match::name("gpu::gelu")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

363
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
364
365
366
367
368
369
370
371
372
373
374
375
376
377
    {
        auto add_ins = r.instructions["add"];
        auto ins     = r.result;
        auto args    = add_ins->inputs();
        move_standard_front(args);
        move_broadcasted_back(args);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_add_gelu{}, args);
    }
};

struct find_gelu_new
{
kahmed10's avatar
kahmed10 committed
378
    bool fast_math = true;
kahmed10's avatar
kahmed10 committed
379

380
    auto matcher() const { return match::gelu_tanh(&gpu_name); }
kahmed10's avatar
kahmed10 committed
381

382
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
383
384
385
386
387
    {
        auto ins   = r.result;
        auto x_ins = r.instructions["x"];
        auto args  = ins->inputs();

Paul Fultz II's avatar
Paul Fultz II committed
388
        if(fast_math)
kahmed10's avatar
kahmed10 committed
389
            p.replace_instruction(ins, hip_gelu{}, x_ins, args.back());
Paul Fultz II's avatar
Paul Fultz II committed
390
391
        else
            p.replace_instruction(ins, hip_gelu_new{}, x_ins, args.back());
kahmed10's avatar
kahmed10 committed
392
393
394
395
396
397
398
399
400
401
    }
};

struct find_add_gelu_new
{
    auto matcher() const
    {
        return match::name("gpu::gelu_new")(match::arg(0)(match::name("gpu::add").bind("add")));
    }

402
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
403
404
405
406
407
408
409
410
411
412
413
414
    {
        auto add_ins = r.instructions["add"];
        auto ins     = r.result;
        auto args    = add_ins->inputs();
        move_standard_front(args);
        move_broadcasted_back(args);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_add_gelu_new{}, args);
    }
};

kahmed10's avatar
kahmed10 committed
415
416
417
418
419
420
struct find_add_clip
{
    auto matcher() const
    {
        return match::name(std::unordered_set<std::string>{"gpu::clip", "gpu::clipped_relu"})(
            match::arg(0)(match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
421
                                        match::name("gpu::triadd"),
kahmed10's avatar
kahmed10 committed
422
423
424
425
                                        match::any_of[match::inputs()](match::standard_shape()))
                              .bind("add")));
    }

426
    void apply(module& p, match::matcher_result r) const
kahmed10's avatar
kahmed10 committed
427
    {
kahmed10's avatar
kahmed10 committed
428
429
430
431
432
433
434
435
436
437
        auto add_ins  = r.instructions["add"];
        auto ins      = r.result;
        auto ins_args = ins->inputs();
        auto add_args = add_ins->inputs();
        move_standard_front(add_args);
        move_broadcasted_back(add_args);

        // Use the allocation from the clip operator
        add_args.pop_back();
        add_args.insert(add_args.end(), std::next(ins_args.begin()), ins_args.end());
kahmed10's avatar
kahmed10 committed
438
        if(add_ins->name() == "gpu::add")
kahmed10's avatar
kahmed10 committed
439
            p.replace_instruction(ins, hip_add_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
440
        else if(add_ins->name() == "gpu::triadd")
kahmed10's avatar
kahmed10 committed
441
            p.replace_instruction(ins, hip_triadd_clip{}, add_args);
kahmed10's avatar
kahmed10 committed
442
443
444
    }
};

445
struct find_add_unary
Paul's avatar
Paul committed
446
{
447
448
449
    std::string op_name;
    operation binary_add_op;
    operation ternary_add_op;
Paul's avatar
Paul committed
450
451
    auto matcher() const
    {
452
        return match::name(op_name)(match::arg(0)(
Paul's avatar
Paul committed
453
            match::used_once(),
Paul's avatar
Paul committed
454
            match::any_of(match::name("gpu::add"),
kahmed10's avatar
kahmed10 committed
455
                          match::name("gpu::triadd"),
Paul's avatar
Paul committed
456
457
458
                          match::any_of(match::name("@literal"),
                                        match::any_of[match::inputs()](match::standard_shape())))
                .bind("add")));
Paul's avatar
Paul committed
459
    }
Paul's avatar
Paul committed
460

461
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
462
    {
Paul's avatar
Paul committed
463
        auto add_ins = r.instructions["add"];
Paul's avatar
Paul committed
464
465
        auto ins     = r.result;
        auto args    = add_ins->inputs();
Paul's avatar
Paul committed
466
467
468
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
469
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
470
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
471
        if(add_ins->name() == "gpu::add")
472
            p.replace_instruction(ins, binary_add_op, args);
kahmed10's avatar
kahmed10 committed
473
        else if(add_ins->name() == "gpu::triadd")
474
            p.replace_instruction(ins, ternary_add_op, args);
Paul's avatar
Paul committed
475
476
477
    }
};

Paul's avatar
Paul committed
478
struct find_triadd
Paul's avatar
Paul committed
479
480
481
{
    auto matcher() const
    {
Paul's avatar
Paul committed
482
        return match::name("gpu::add")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
483
            match::name("gpu::add")(match::used_once()).bind("add"),
Paul's avatar
Paul committed
484
485
486
            match::any(match::any_of(match::name("@literal"),
                                     match::any_of[match::inputs()](match::standard_shape())))
                .bind("input")));
Paul's avatar
Paul committed
487
488
    }

489
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
490
    {
Paul's avatar
Paul committed
491
492
493
494
        auto add_ins   = r.instructions["add"];
        auto input_ins = r.instructions["input"];
        auto ins       = r.result;
        auto args      = add_ins->inputs();
495

Paul's avatar
Paul committed
496
        auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
497
        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 2)
Paul's avatar
Paul committed
498
499
            return;
        args.insert(args.begin(), input_ins);
Paul's avatar
Paul committed
500
501
502
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
503
504
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_triadd{}, args);
Paul's avatar
Paul committed
505
    }
Paul's avatar
Paul committed
506
507
};

Paul's avatar
Paul committed
508
509
510
511
struct find_mul_add
{
    auto matcher() const
    {
Paul's avatar
Paul committed
512
513
        return match::name("gpu::add")(match::either_arg(0, 1)(
            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
Paul's avatar
Paul committed
514
515
    }

516
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
517
    {
Paul's avatar
Paul committed
518
519
520
521
        auto mul_ins = r.instructions["mul"];
        auto b_ins   = r.instructions["b"];
        auto ins     = r.result;
        auto args    = mul_ins->inputs();
Paul's avatar
Paul committed
522
523
524
525
526
527
528
529
530
531
532
        assert(mul_ins != b_ins);

        move_standard_front(args);
        move_broadcasted_back(args);
        args.insert(std::prev(args.end()), b_ins);

        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_mul_add{}, args);
    }
};

Paul's avatar
Paul committed
533
534
535
536
struct find_mul_add_relu
{
    auto matcher() const
    {
Paul's avatar
Paul committed
537
        return match::name("gpu::relu")(
kahmed10's avatar
kahmed10 committed
538
            match::arg(0)(match::name("gpu::mul_add")(match::used_once()).bind("mul_add")));
Paul's avatar
Paul committed
539
540
    }

541
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
542
543
    {
        auto mul_add_ins = r.instructions["mul_add"];
Paul's avatar
Paul committed
544
545
        auto ins         = r.result;
        auto args        = mul_add_ins->inputs();
Paul's avatar
Paul committed
546
547
548
549
550
551
552

        // Use the allocation from the relu operator
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_mul_add_relu{}, args);
    }
};

Paul's avatar
Paul committed
553
554
555
struct miopen_conv_bias
{
    op::convolution op;
556
557
558
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
Paul's avatar
Paul committed
559

Paul's avatar
Paul committed
560
561
562
563
564
565
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Paul committed
566
567
568
569
570
    std::string name() const { return "gpu::conv_bias"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
kahmed10's avatar
kahmed10 committed
571
        return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
Paul's avatar
Paul committed
572
    }
Paul's avatar
Paul committed
573
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
574
    {
Paul's avatar
Paul committed
575
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
576
        float alpha = 1;
Paul's avatar
Paul committed
577
        float beta  = 0;
Paul's avatar
Paul committed
578
579
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
580
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Paul committed
581
582
    }

583
584
585
586
587
588
589
590
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        f.compile(ctx);
    }

Paul's avatar
Paul committed
591
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
592
593
594
595
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
596
};
597
MIGRAPHX_REGISTER_OP(miopen_conv_bias)
Paul's avatar
Paul committed
598

Paul's avatar
Add cbr  
Paul committed
599
600
601
struct miopen_conv_bias_relu
{
    op::convolution op;
602
603
604
605
    fusion f          = {};
    fusion::op_t conv = {};
    fusion::op_t bias = {};
    fusion::op_t relu = {};
Paul's avatar
Add cbr  
Paul committed
606

Paul's avatar
Paul committed
607
608
609
610
611
612
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Add cbr  
Paul committed
613
614
615
616
617
    std::string name() const { return "gpu::conv_bias_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
kahmed10's avatar
kahmed10 committed
618
        return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
Paul's avatar
Add cbr  
Paul committed
619
    }
Paul's avatar
Paul committed
620
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Add cbr  
Paul committed
621
622
    {
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
623
        float alpha = 1;
Paul's avatar
Paul committed
624
        float beta  = 0;
Paul's avatar
Add cbr  
Paul committed
625
626
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
627
628
        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Add cbr  
Paul committed
629
    }
630
631
632
633
634
635
636
637
638
    void finalize(context& ctx, const shape&, const std::vector<shape>& inputs)
    {
        f    = fusion(inputs[0]);
        conv = f.create_conv(op, inputs[1]);
        bias = f.create_bias(inputs[3]);
        relu = f.create_relu();
        f.compile(ctx);
    }

Paul's avatar
Paul committed
639
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
640
641
642
643
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Add cbr  
Paul committed
644
};
645
MIGRAPHX_REGISTER_OP(miopen_conv_bias_relu)
Paul's avatar
Add cbr  
Paul committed
646

Paul's avatar
Paul committed
647
template <class... Ms>
Paul's avatar
Add cbr  
Paul committed
648
649
auto conv_bias(Ms... ms)
{
Paul's avatar
Paul committed
650
    return match::name("gpu::add")(
Paul's avatar
Paul committed
651
652
        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
                                fusable_conv(match::used_once()).bind("conv")),
Paul's avatar
Paul committed
653
        ms...);
Paul's avatar
Paul committed
654
655
}

Paul's avatar
Paul committed
656
template <class Op>
657
void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
Paul's avatar
Paul committed
658
659
660
661
662
663
664
665
666
667
{
    auto conv_ins    = r.instructions["conv"];
    auto bias_ins    = r.instructions["bias"];
    auto ins         = r.result;
    auto input_ins   = conv_ins->inputs().at(0);
    auto weights_ins = conv_ins->inputs().at(1);
    auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
    auto alloc_ins   = ins->inputs().back();
    auto old_ws_ins  = conv_ins->inputs().at(2);

668
    Op cb{conv_op};
Paul's avatar
Paul committed
669
    // TODO: Insert ws allocation
Paul's avatar
Paul committed
670
    auto ws = cb.get_workspace(ctx);
Paul's avatar
Paul committed
671
    (void)ws;
Paul's avatar
Paul committed
672
    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
Paul's avatar
Add cbr  
Paul committed
673
674
}

Paul's avatar
Paul committed
675
struct find_conv_bias
Paul's avatar
Paul committed
676
{
Paul's avatar
Paul committed
677
    context* ctx = nullptr;
Paul's avatar
Paul committed
678
679
    auto matcher() const
    {
kahmed10's avatar
kahmed10 committed
680
681
        return conv_bias(match::none_of(
            match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
Paul's avatar
Paul committed
682
683
    }

684
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Paul committed
685
    {
Paul's avatar
Paul committed
686
        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
Paul's avatar
Paul committed
687
688
689
    }
};

Paul's avatar
Paul committed
690
struct find_conv_bias_relu
Paul's avatar
Add cbr  
Paul committed
691
692
{
    context* ctx = nullptr;
Paul's avatar
Paul committed
693
    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
Paul's avatar
Add cbr  
Paul committed
694

695
    void apply(module& p, match::matcher_result r) const
Paul's avatar
Add cbr  
Paul committed
696
    {
Paul's avatar
Paul committed
697
        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
Paul's avatar
Add cbr  
Paul committed
698
699
700
    }
};

701
702
703
704
705
706
707
708
709
710
struct find_gemm_add
{
    auto matcher() const
    {
        return match::name("gpu::add")(
            match::all_of[match::inputs()](match::standard_shape()),
            match::either_arg(0, 1)(match::used_once().bind("c"),
                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
    }

711
    void apply(module& p, match::matcher_result r) const
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
    {
        auto ins      = r.result;
        auto gemm_ins = r.instructions["gemm"];
        auto c_ins    = r.instructions["c"];

        auto gemm = any_cast<rocblas_gemm<op::dot>>(gemm_ins->get_operator());

        // Already fused gemm
        if(not float_equal(gemm.op.beta, 0))
            return;

        if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i) {
               return not i->get_shape().standard();
           }))
            return;

        auto inputs = gemm_ins->inputs();
        inputs.pop_back();

        auto copy_ins = c_ins;

        // Insert copy
        if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
        {
            copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
        }
        inputs.push_back(copy_ins);
        inputs.push_back(copy_ins);

        gemm.op.beta = 1;
        p.replace_instruction(ins, gemm, inputs);
    }
};

struct find_commutative_broadcast
{
    auto matcher() const
    {
        return match::name("gpu::add", "gpu::mul")(match::arg(1)(match::broadcast_shape()));
    }

753
    void apply(module& p, const match::matcher_result& r) const
754
755
756
757
758
759
760
761
762
    {
        auto ins  = r.result;
        auto args = ins->inputs();
        move_broadcasted_back(args);

        p.replace_instruction(ins, ins->get_operator(), args);
    }
};

763
// Runs the GPU fusion matchers over module p in successive rounds, with dead
// code elimination after the first and after the main round.
void fuse_ops::apply(module& p) const
{
    // Round 1: gelu patterns
    match::find_matches(p, find_gelu{}, find_gelu_new{fast_math});
    run_passes(p, {dead_code_elimination{}});

    // Round 2: chained adds become triadd before the other add-based fusions run
    match::find_matches(p, find_triadd{});

    // Round 3: layernorm, convolution, and add/mul + activation fusions
    match::find_matches(p,
                        find_layernorm{},
                        find_conv_bias_relu{ctx},
                        find_conv_bias{ctx},
                        find_add_gelu{},
                        find_add_gelu_new{},
                        find_mul_add{},
                        find_mul_add_relu{},
                        find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
                        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                        find_add_clip{});
    run_passes(p, {dead_code_elimination{}});

    // Round 4: fusions that build on the results of the previous round
    match::find_matches(p, find_triadd_layernorm{}, find_gemm_add{}, find_commutative_broadcast{});
}
Paul's avatar
Paul committed
783
784

} // namespace gpu
Paul's avatar
Paul committed
785
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
786
} // namespace migraphx