fuse_ops.cpp 19.4 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
kahmed10's avatar
kahmed10 committed
4
#include <migraphx/gpu/clip.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/gpu/convolution.hpp>
6
#include <migraphx/gpu/oper.hpp>
Paul's avatar
Paul committed
7
#include <migraphx/gpu/device/mul_add.hpp>
8
#include <migraphx/gpu/device/add_unary.hpp>
Paul's avatar
Paul committed
9
#include <migraphx/gpu/device/add.hpp>
Paul's avatar
Paul committed
10
#include <migraphx/instruction.hpp>
Paul's avatar
Paul committed
11
#include <migraphx/array.hpp>
kahmed10's avatar
kahmed10 committed
12
#include <migraphx/op/clip.hpp>
Paul's avatar
Paul committed
13
14

namespace migraphx {
Paul's avatar
Paul committed
15
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
16
17
namespace gpu {

18
19
// Environment switch for the MIOpen fusion plans in this file: when it is enabled,
// the fusable_conv matcher rejects every convolution, so no conv+bias(+relu)
// fusion is formed.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MIOPEN_FUSION)

Paul's avatar
Paul committed
20
21
22
23
24
25
26
27
struct fusion
{
    using op_t = miopenFusionOpDescriptor_t;
    shared<fusion_plan_descriptor> fp;

    // Used as a temporary hack to keep descriptor references alive
    std::vector<std::shared_ptr<void>> storage;

Paul's avatar
Paul committed
28
    template <class T>
Paul's avatar
Paul committed
29
30
31
32
33
34
35
36
37
38
39
    auto keep_alive(T x)
    {
        auto result = share(std::move(x));
        storage.push_back(result);
        return result;
    }

    fusion(const shape& input)
    // : fp(make_fusion_plan(input))
    {
        auto t = make_tensor(input);
Paul's avatar
Paul committed
40
        fp     = make_fusion_plan(t);
Paul's avatar
Paul committed
41
42
43
44
45
46
47
48
        keep_alive(std::move(t));
    }

    op_t operator[](std::size_t i) const
    {
        op_t result;
        auto status = miopenFusionPlanGetOp(fp.get(), i, &result);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
49
            MIGRAPHX_THROW("Failed retrieving operator at " + std::to_string(i));
Paul's avatar
Paul committed
50
51
52
        return result;
    }

Paul's avatar
Paul committed
53
    auto get() const { return fp.get(); }
Paul's avatar
Paul committed
54
55
56
57

    op_t create_bias(const shape& bias)
    {
        op_t result;
Paul's avatar
Paul committed
58
59
        auto b      = shape{bias.type(), {1, bias.lens().at(1), 1, 1}};
        auto t      = keep_alive(make_tensor(b));
Paul's avatar
Paul committed
60
61
        auto status = miopenCreateOpBiasForward(fp.get(), &result, t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
62
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
63
64
65
66
67
68
69
70
        return result;
    }

    op_t create_relu()
    {
        op_t result;
        auto status = miopenCreateOpActivationForward(fp.get(), &result, miopenActivationRELU);
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
71
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
72
73
74
75
76
77
        return result;
    }

    op_t create_conv(const op::convolution& op, const shape& weights)
    {
        op_t result;
Paul's avatar
Paul committed
78
79
        auto cd     = keep_alive(make_conv(op));
        auto t      = keep_alive(make_tensor(weights));
Paul's avatar
Paul committed
80
81
        auto status = miopenCreateOpConvForward(fp.get(), &result, cd.get(), t.get());
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
82
            MIGRAPHX_THROW("Creating operator failed");
Paul's avatar
Paul committed
83
84
        return result;
    }
Paul's avatar
Paul committed
85
86
87
88
89
90
91
92

    shape get_workspace(context&)
    {
        // TODO: Use zero workspace for now
        std::size_t ws_size = 0;
        // int algo_count = 1;
        // miopenConvFwdAlgorithm_t algo;
        // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
Paul's avatar
Paul committed
93
94
        // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
        // algo);
Paul's avatar
Paul committed
95
96
97
98
99
        return shape{shape::int8_type, {ws_size}};
    }

    void compile(context& ctx)
    {
Paul's avatar
Paul committed
100
        auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
Paul's avatar
Paul committed
101
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
102
            MIGRAPHX_THROW("Compiling fusion plan failed");
Paul's avatar
Paul committed
103
104
    }

Paul's avatar
Paul committed
105
106
107
108
    argument execute(context& ctx,
                     const fused_operator_args& fargs,
                     const argument& x,
                     const argument& y) const
Paul's avatar
Paul committed
109
    {
Paul's avatar
Paul committed
110
111
        auto x_td   = make_tensor(x.get_shape());
        auto y_td   = make_tensor(y.get_shape());
Paul's avatar
Paul committed
112
        auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(),
Paul's avatar
Paul committed
113
114
115
116
117
118
                                              fp.get(),
                                              x_td.get(),
                                              x.implicit(),
                                              y_td.get(),
                                              y.implicit(),
                                              fargs.get());
Paul's avatar
Paul committed
119
        if(status != miopenStatusSuccess)
Paul's avatar
Paul committed
120
            MIGRAPHX_THROW("Failed to execute fusion plan");
Paul's avatar
Paul committed
121
122
        return y;
    }
Paul's avatar
Paul committed
123
124
};

Paul's avatar
Paul committed
125
// Predicate matcher: true when `ins` has a broadcasted rank-4 shape whose only
// non-zero stride is on dimension 1. This is the layout of a per-channel bias
// broadcast over N, H and W.
MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
{
    const auto& s = ins->get_shape();
    if(not s.broadcasted())
        return false;
    const auto& strides = s.strides();
    if(strides.size() != 4)
        return false;
    const bool only_channel_varies = strides[1] != 0;
    return strides[0] == 0 and only_channel_varies and strides[2] == 0 and strides[3] == 0;
}

Paul's avatar
Paul committed
132
// Predicate matcher: true when `ins` is a gpu::convolution that can head a MIOpen
// conv+bias(+relu) fusion plan.
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
    // Fusion can be switched off globally through the environment
    if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
        return false;
    if(ins->name() != "gpu::convolution")
        return false;
    // Only fp32 convolutions are fused
    if(ins->get_shape().type() != shape::float_type)
        return false;
    auto wei = ins->inputs().at(1)->get_shape();
    assert(wei.lens().size() == 4);
    auto conv = any_cast<miopen_convolution>(ins->get_operator());
    // Grouped convolutions are not fused
    if(conv.op.group > 1)
        return false;
    // Weights larger than 512 along dimension 1 are only fused for winograd
    if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
        return false;
    auto op = conv.op;
    // Dont fuse winograd for non-3x3s since there is no fused winograd for those
    // configs. A kernel is non-3x3 when EITHER spatial dimension differs from 3;
    // requiring both to differ would let 3xK / Kx3 winograd kernels through.
    const bool is_3x3_kernel = wei.lens()[2] == 3 and wei.lens()[3] == 3;
    if(conv.algo == miopenConvolutionFwdAlgoWinograd and not is_3x3_kernel and
       op.stride == make_array<size_t>(1, 1))
        return false;
    // Only these padding / stride / dilation combinations are fused
    return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
           contains({{0, 0}, {1, 1}}, op.stride) and op.dilation == make_array<size_t>(1, 1);
}

Paul's avatar
Paul committed
156
157
158
159
160
161
162
163
struct hip_triadd
{
    std::string name() const { return "hip::triadd"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }
Paul's avatar
Paul committed
164
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
165
    {
Paul's avatar
Paul committed
166
        device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
Paul's avatar
Paul committed
167
168
        return args.at(3);
    }
Paul's avatar
Paul committed
169
170
171
172
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
173
174
};

kahmed10's avatar
kahmed10 committed
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
// Fused three-way add followed by clip. Calls device::add_clip on args[0..2] and
// writes into args[3] (the output allocation), bounded by the wrapped op::clip's
// max_val / min_val.
struct hip_triadd_clip
{
    op::clip op;

    // Reflection forwards to op::clip, so this operator exposes the same attributes.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::clip::reflect(self.op, f);
    }

    std::string name() const { return "hip::triadd_clip"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        const auto& output = args.at(3);
        const auto upper   = op.max_val;
        const auto lower   = op.min_val;
        device::add_clip(
            ctx.get_stream().get(), output, args.at(0), args.at(1), args.at(2), upper, lower);
        return output;
    }

    // The result lives in the last (allocation) input.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return static_cast<std::ptrdiff_t>(shapes.size()) - 1;
    }
};

// Fused two-way add followed by clip. Calls device::add_clip on args[0] and args[1]
// and writes into args[2] (the output allocation), bounded by the wrapped op::clip's
// max_val / min_val.
struct hip_add_clip
{
    op::clip op;

    // Reflection forwards to op::clip, so this operator exposes the same attributes.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::clip::reflect(self.op, f);
    }

    std::string name() const { return "hip::add_clip"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(3);
        return inputs.front();
    }

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        const auto& output = args.at(2);
        const auto upper   = op.max_val;
        const auto lower   = op.min_val;
        device::add_clip(ctx.get_stream().get(), output, args.at(0), args.at(1), upper, lower);
        return output;
    }

    // The result lives in the last (allocation) input.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return static_cast<std::ptrdiff_t>(shapes.size()) - 1;
    }
};

234
// Three-input add fused with ReLU; the kernel is device::add_relu.
struct hip_triadd_relu : ternary_device<hip_triadd_relu, &device::add_relu> {};

238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
// Three-input add fused with sigmoid; the kernel is device::add_sigmoid.
struct hip_triadd_sigmoid : ternary_device<hip_triadd_sigmoid, &device::add_sigmoid> {};

// Three-input add fused with tanh; the kernel is device::add_tanh.
struct hip_triadd_tanh : ternary_device<hip_triadd_tanh, &device::add_tanh> {};

// Two-input add fused with ReLU; the kernel is device::add_relu.
struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu> {};

// Two-input add fused with sigmoid; the kernel is device::add_sigmoid.
// The first template argument is this struct itself (CRTP), as for every sibling
// fused operator; binary_device presumably derives the operator's identity (e.g.
// its name) from it, so naming another struct here would make this op impersonate it.
struct hip_add_sigmoid : binary_device<hip_add_sigmoid, &device::add_sigmoid>
{
};

// Two-input add fused with tanh; the kernel is device::add_tanh.
struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh> {};

Paul's avatar
Paul committed
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
// Fused multiply-add. Runs device::mul_add on args[0], args[1] and args[2] and writes
// into args[3], the output allocation, which is returned. find_mul_add builds the
// inputs as (mul operand, mul operand, addend, allocation).
struct hip_mul_add
{
    std::string name() const { return "hip::mul_add"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        const auto& output = args.at(3);
        device::mul_add(ctx.get_stream().get(), output, args.at(0), args.at(1), args.at(2));
        return output;
    }

    // The result lives in the last (allocation) input.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return static_cast<std::ptrdiff_t>(shapes.size()) - 1;
    }
};

Paul's avatar
Paul committed
277
278
279
280
281
282
283
284
285
286
// Fused multiply-add followed by ReLU. Runs device::mul_add_relu on args[0..2] and
// writes into args[3], the output allocation, which is returned.
struct hip_mul_add_relu
{
    std::string name() const { return "hip::mul_add_relu"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(4);
        return inputs.front();
    }

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        const auto& output = args.at(3);
        device::mul_add_relu(
            ctx.get_stream().get(), output, args.at(0), args.at(1), args.at(2));
        return output;
    }

    // The result lives in the last (allocation) input.
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return static_cast<std::ptrdiff_t>(shapes.size()) - 1;
    }
};

Paul's avatar
Paul committed
297
298
299
void move_broadcasted_back(std::vector<instruction_ref>& args)
{
    // Ensure the last arguments is the broadcasted one
Paul's avatar
Paul committed
300
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
301
302
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
Paul's avatar
Paul committed
303
304
    if(it != last)
        std::swap(*it, *std::prev(last));
Paul's avatar
Paul committed
305
306
307
308
309
}

void move_standard_front(std::vector<instruction_ref>& args)
{
    // Ensure the first arguments is the standard one
Paul's avatar
Paul committed
310
    auto last = std::prev(args.end());
Paul's avatar
Paul committed
311
312
    auto it =
        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
Paul's avatar
Paul committed
313
    if(it != last)
Paul's avatar
Paul committed
314
315
316
        std::swap(*it, args.front());
}

kahmed10's avatar
kahmed10 committed
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
// Fuses clip(gpu::add) / clip(hip::triadd) into hip::add_clip / hip::triadd_clip.
// The fused op writes into the clip's output allocation.
struct find_add_clip
{
    auto matcher() const
    {
        return match::name(std::unordered_set<std::string>{"gpu::clip", "gpu::clipped_relu"})(
            match::arg(0)(match::any_of(match::name("gpu::add"),
                                        match::name("hip::triadd"),
                                        match::any_of[match::inputs()](match::standard_shape()))
                              .bind("add")));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto add_ins = r.instructions["add"];
        auto ins     = r.result;

        // The matcher's last alternative accepts any instruction with a standard-shaped
        // input, but only gpu::add and hip::triadd have fused clip kernels. Check that
        // before extracting the clip operator or reordering any arguments.
        const bool is_add    = add_ins->name() == "gpu::add";
        const bool is_triadd = add_ins->name() == "hip::triadd";
        if(not is_add and not is_triadd)
            return;

        // NOTE(review): the matcher also accepts "gpu::clipped_relu"; this cast assumes
        // that operator is a gpu::hip_clip and presumably throws otherwise — confirm.
        auto op   = any_cast<gpu::hip_clip>(ins->get_operator()).op;
        auto args = add_ins->inputs();
        move_standard_front(args);
        move_broadcasted_back(args);

        // Use the allocation from the clip operator
        args.back() = ins->inputs().back();
        if(is_add)
            p.replace_instruction(ins, hip_add_clip{op}, args);
        else
            p.replace_instruction(ins, hip_triadd_clip{op}, args);
    }
};

346
struct find_add_unary
Paul's avatar
Paul committed
347
{
348
349
350
    std::string op_name;
    operation binary_add_op;
    operation ternary_add_op;
Paul's avatar
Paul committed
351
352
    auto matcher() const
    {
353
        return match::name(op_name)(match::arg(0)(
Paul's avatar
Paul committed
354
            match::used_once(),
Paul's avatar
Paul committed
355
356
357
358
359
            match::any_of(match::name("gpu::add"),
                          match::name("hip::triadd"),
                          match::any_of(match::name("@literal"),
                                        match::any_of[match::inputs()](match::standard_shape())))
                .bind("add")));
Paul's avatar
Paul committed
360
    }
Paul's avatar
Paul committed
361

Paul's avatar
Paul committed
362
363
    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
364
        auto add_ins = r.instructions["add"];
Paul's avatar
Paul committed
365
366
        auto ins     = r.result;
        auto args    = add_ins->inputs();
Paul's avatar
Paul committed
367
368
369
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
370
        // Use the allocation from the relu operator
Paul's avatar
Paul committed
371
        args.back() = ins->inputs().back();
Paul's avatar
Paul committed
372
        if(add_ins->name() == "gpu::add")
373
            p.replace_instruction(ins, binary_add_op, args);
Paul's avatar
Paul committed
374
        else if(add_ins->name() == "hip::triadd")
375
            p.replace_instruction(ins, ternary_add_op, args);
Paul's avatar
Paul committed
376
377
378
    }
};

Paul's avatar
Paul committed
379
struct find_triadd
Paul's avatar
Paul committed
380
381
382
{
    auto matcher() const
    {
Paul's avatar
Paul committed
383
        return match::name("gpu::add")(match::either_arg(0, 1)(
Paul's avatar
Paul committed
384
            match::name("gpu::add")(match::used_once()).bind("add"),
Paul's avatar
Paul committed
385
386
387
            match::any(match::any_of(match::name("@literal"),
                                     match::any_of[match::inputs()](match::standard_shape())))
                .bind("input")));
Paul's avatar
Paul committed
388
389
390
391
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
392
393
394
395
        auto add_ins   = r.instructions["add"];
        auto input_ins = r.instructions["input"];
        auto ins       = r.result;
        auto args      = add_ins->inputs();
396
397
        assert(add_ins != input_ins);

Paul's avatar
Paul committed
398
399
400
401
        auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
        if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
            return;
        args.insert(args.begin(), input_ins);
Paul's avatar
Paul committed
402
403
404
        move_standard_front(args);
        move_broadcasted_back(args);

Paul's avatar
Paul committed
405
406
        args.back() = ins->inputs().back();
        p.replace_instruction(ins, hip_triadd{}, args);
Paul's avatar
Paul committed
407
    }
Paul's avatar
Paul committed
408
409
};

Paul's avatar
Paul committed
410
411
412
413
// Rewrites gpu::add(gpu::mul(a, b), c) as hip::mul_add(a, b, c) when the mul has no
// other users. The fused op writes into the add's output allocation.
struct find_mul_add
{
    auto matcher() const
    {
        auto mul_m    = match::name("gpu::mul")(match::used_once()).bind("mul");
        auto addend_m = match::any().bind("b");
        return match::name("gpu::add")(match::either_arg(0, 1)(mul_m, addend_m));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto add_ins = r.result;
        auto mul_ins = r.instructions["mul"];
        auto addend  = r.instructions["b"];
        assert(mul_ins != addend);

        auto operands = mul_ins->inputs();
        move_standard_front(operands);
        move_broadcasted_back(operands);
        // Slot the addend in just before the trailing allocation, then point that
        // allocation at the add's output buffer.
        operands.insert(std::prev(operands.end()), addend);
        operands.back() = add_ins->inputs().back();
        p.replace_instruction(add_ins, hip_mul_add{}, operands);
    }
};

Paul's avatar
Paul committed
435
436
437
438
// Folds gpu::relu(hip::mul_add(...)) into hip::mul_add_relu when the mul_add result
// has no other users.
struct find_mul_add_relu
{
    auto matcher() const
    {
        auto mul_add_m = match::name("hip::mul_add")(match::used_once()).bind("mul_add");
        return match::name("gpu::relu")(match::arg(0)(mul_add_m));
    }

    void apply(program& p, match::matcher_result r) const
    {
        auto relu_ins = r.result;
        auto operands = r.instructions["mul_add"]->inputs();
        // Use the allocation from the relu operator
        operands.back() = relu_ins->inputs().back();
        p.replace_instruction(relu_ins, hip_mul_add_relu{}, operands);
    }
};

Paul's avatar
Paul committed
455
456
457
458
459
460
461
struct miopen_conv_bias
{
    op::convolution op;
    fusion f;
    fusion::op_t conv;
    fusion::op_t bias;

Paul's avatar
Paul committed
462
463
464
465
466
467
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Paul committed
468
469
    miopen_conv_bias(op::convolution c, const shape& input, const shape& weights, const shape& b)
        : op(c), f(input)
Paul's avatar
Paul committed
470
    {
Paul's avatar
Paul committed
471
472
        conv = f.create_conv(op, weights);
        bias = f.create_bias(b);
Paul's avatar
Paul committed
473
474
475
476
477
478
479
480
481
    }

    std::string name() const { return "gpu::conv_bias"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
482
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
483
    {
Paul's avatar
Paul committed
484
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
485
        float alpha = 1;
Paul's avatar
Paul committed
486
        float beta  = 0;
Paul's avatar
Paul committed
487
488
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
489
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Paul committed
490
491
    }

Paul's avatar
Paul committed
492
493
    void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
494
495
496
497
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Paul committed
498
499
};

Paul's avatar
Add cbr  
Paul committed
500
501
502
503
504
505
struct miopen_conv_bias_relu
{
    op::convolution op;
    fusion f;
    fusion::op_t conv;
    fusion::op_t bias;
Paul's avatar
Paul committed
506
    fusion::op_t relu;
Paul's avatar
Add cbr  
Paul committed
507

Paul's avatar
Paul committed
508
509
510
511
512
513
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return op::convolution::reflect(self.op, f);
    }

Paul's avatar
Paul committed
514
515
516
517
518
    miopen_conv_bias_relu(op::convolution c,
                          const shape& input,
                          const shape& weights,
                          const shape& b)
        : op(c), f(input)
Paul's avatar
Add cbr  
Paul committed
519
    {
Paul's avatar
Paul committed
520
521
522
        conv = f.create_conv(op, weights);
        bias = f.create_bias(b);
        relu = f.create_relu();
Paul's avatar
Add cbr  
Paul committed
523
524
525
526
527
528
529
530
531
    }

    std::string name() const { return "gpu::conv_bias_relu"; }
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(5);
        // TODO: Check slices
        return op.compute_shape({inputs.at(0), inputs.at(1)});
    }
Paul's avatar
Paul committed
532
    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
Paul's avatar
Add cbr  
Paul committed
533
534
    {
        auto fargs  = make_fused_args();
Paul's avatar
Paul committed
535
        float alpha = 1;
Paul's avatar
Paul committed
536
        float beta  = 0;
Paul's avatar
Add cbr  
Paul committed
537
538
        miopenSetOpArgsConvForward(fargs.get(), conv, &alpha, &beta, args[1].implicit());
        miopenSetOpArgsBiasForward(fargs.get(), bias, &alpha, &beta, args[3].implicit());
Paul's avatar
Paul committed
539
540
        miopenSetOpArgsActivForward(fargs.get(), relu, &alpha, &beta, 0, 0, 0);
        return f.execute(ctx, fargs, args[0], args[4]);
Paul's avatar
Add cbr  
Paul committed
541
    }
Paul's avatar
Paul committed
542
543
    void finalize(context& ctx, const shape&, const std::vector<shape>&) { f.compile(ctx); }
    shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
Paul's avatar
Paul committed
544
545
546
547
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
Paul's avatar
Add cbr  
Paul committed
548
549
};

Paul's avatar
Paul committed
550
template <class... Ms>
Paul's avatar
Add cbr  
Paul committed
551
552
auto conv_bias(Ms... ms)
{
Paul's avatar
Paul committed
553
    return match::name("gpu::add")(
Paul's avatar
Paul committed
554
555
        match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
                                fusable_conv(match::used_once()).bind("conv")),
Paul's avatar
Paul committed
556
        ms...);
Paul's avatar
Paul committed
557
558
}

Paul's avatar
Paul committed
559
template <class Op>
Paul's avatar
Paul committed
560
561
562
563
564
565
566
567
568
569
570
void apply_conv_bias(context& ctx, program& p, match::matcher_result r)
{
    auto conv_ins    = r.instructions["conv"];
    auto bias_ins    = r.instructions["bias"];
    auto ins         = r.result;
    auto input_ins   = conv_ins->inputs().at(0);
    auto weights_ins = conv_ins->inputs().at(1);
    auto conv_op     = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
    auto alloc_ins   = ins->inputs().back();
    auto old_ws_ins  = conv_ins->inputs().at(2);

Paul's avatar
Paul committed
571
    Op cb{conv_op, input_ins->get_shape(), weights_ins->get_shape(), bias_ins->get_shape()};
Paul's avatar
Paul committed
572
    // TODO: Insert ws allocation
Paul's avatar
Paul committed
573
    auto ws = cb.get_workspace(ctx);
Paul's avatar
Paul committed
574
    (void)ws;
Paul's avatar
Paul committed
575
    p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
Paul's avatar
Add cbr  
Paul committed
576
577
}

Paul's avatar
Paul committed
578
struct find_conv_bias
Paul's avatar
Paul committed
579
{
Paul's avatar
Paul committed
580
    context* ctx = nullptr;
Paul's avatar
Paul committed
581
582
    auto matcher() const
    {
kahmed10's avatar
kahmed10 committed
583
584
        return conv_bias(match::none_of(
            match::output(match::name(std::unordered_set<std::string>{"gpu::relu"}))));
Paul's avatar
Paul committed
585
586
587
588
    }

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
589
        apply_conv_bias<miopen_conv_bias>(*ctx, p, std::move(r));
Paul's avatar
Paul committed
590
591
592
    }
};

Paul's avatar
Paul committed
593
struct find_conv_bias_relu
Paul's avatar
Add cbr  
Paul committed
594
595
{
    context* ctx = nullptr;
Paul's avatar
Paul committed
596
    auto matcher() const { return match::name("gpu::relu")(match::arg(0)(conv_bias())); }
Paul's avatar
Add cbr  
Paul committed
597
598
599

    void apply(program& p, match::matcher_result r) const
    {
Paul's avatar
Paul committed
600
        apply_conv_bias<miopen_conv_bias_relu>(*ctx, p, std::move(r));
Paul's avatar
Add cbr  
Paul committed
601
602
603
    }
};

Paul's avatar
Paul committed
604
605
void fuse_ops::apply(program& p) const
{
    // Pass 1: fold add(add(a, b), c) into hip::triadd. It runs as its own sweep so
    // the activation/clip finders in pass 2 can also match on hip::triadd.
    match::find_matches(p, find_triadd{});

    // Pass 2: MIOpen conv+bias(+relu), mul+add(+relu), and add/triadd followed by
    // relu, sigmoid, tanh or clip.
    // clang-format off
    match::find_matches(p,
                        find_conv_bias_relu{ctx},
                        find_conv_bias{ctx},
                        find_mul_add{},
                        find_mul_add_relu{},
                        find_add_unary{"gpu::relu", hip_add_relu{}, hip_triadd_relu{}},
                        find_add_unary{"gpu::sigmoid", hip_add_sigmoid{}, hip_triadd_sigmoid{}},
                        find_add_unary{"gpu::tanh", hip_add_tanh{}, hip_triadd_tanh{}},
                        find_add_clip{});
    // clang-format on
}
Paul's avatar
Paul committed
620
621

} // namespace gpu
Paul's avatar
Paul committed
622
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
623
} // namespace migraphx