lowering.cpp 17 KB
Newer Older
1
#include <rocblas.h>
#include <migraph/gpu/lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/add.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <stdexcept>
#include <utility>
Paul's avatar
Paul committed
17
18

namespace migraph {
Paul's avatar
Paul committed
19
namespace gpu {
Paul's avatar
Paul committed
20

21
22
23
24
25
26
struct miopen_batch_norm_inference
{
    batch_norm_inference op;

    std::string name() const { return "gpu::batch_norm_inference"; }

Paul's avatar
Paul committed
27
    shape compute_shape(const std::vector<shape>& inputs) const
28
29
30
31
32
33
    {
        check_shapes{inputs, *this}.has(6);
        return op.compute_shape(
            {inputs.at(0), inputs.at(1), inputs.at(2), inputs.at(3), inputs.at(4)});
    }

Paul's avatar
Paul committed
34
35
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
36
    {
wsttiger's avatar
wsttiger committed
37
38
        auto x_desc  = make_tensor(args[0].get_shape());
        auto y_desc  = make_tensor(output_shape);
39
        auto bn_desc = make_tensor(args[3].get_shape());
40
41
42
43
44
45
46
47
48
49
50

        float alpha = 1.0, beta = 0.0f;

        miopenBatchNormalizationForwardInference(ctx.handle.get(),
                                                 miopenBatchNormMode_t(op.bn_mode),
                                                 &alpha,
                                                 &beta,
                                                 x_desc.get(),
                                                 args[0].implicit(),
                                                 y_desc.get(),
                                                 args[5].implicit(),
51
                                                 bn_desc.get(),
52
53
                                                 args[1].implicit(),
                                                 args[2].implicit(),
Paul's avatar
Paul committed
54
55
                                                 args[3].implicit(),
                                                 args[4].implicit(),
56
                                                 op.epsilon);
57
58
59
60
61

        return args[5];
    }
};

Paul's avatar
Paul committed
62
63
64
struct miopen_convolution
{
    convolution op;
Paul's avatar
Paul committed
65
    shared<convolution_descriptor> cd;
Paul's avatar
Paul committed
66
    miopenConvFwdAlgorithm_t algo{};
Paul's avatar
Paul committed
67

Paul's avatar
Paul committed
68
    std::string name() const { return "gpu::convolution"; }
Paul's avatar
Paul committed
69
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
70
    {
Paul's avatar
Paul committed
71
        check_shapes{inputs, *this}.has(4).standard();
Paul's avatar
Paul committed
72
        return op.compute_shape({inputs.at(0), inputs.at(1)});
Paul's avatar
Paul committed
73
    }
Paul's avatar
Paul committed
74
75
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
76
    {
Paul's avatar
Paul committed
77
78
        auto x_desc = make_tensor(args[0].get_shape());
        auto w_desc = make_tensor(args[1].get_shape());
Paul's avatar
Paul committed
79
80
        auto y_desc = make_tensor(output_shape);

Paul's avatar
Paul committed
81
        float alpha = 1, beta = 0;
Paul's avatar
Paul committed
82
        miopenConvolutionForward(ctx.handle.get(),
Paul's avatar
Paul committed
83
                                 &alpha,
Paul's avatar
Paul committed
84
                                 x_desc.get(),
Paul's avatar
Paul committed
85
                                 args[0].implicit(),
Paul's avatar
Paul committed
86
                                 w_desc.get(),
Paul's avatar
Paul committed
87
                                 args[1].implicit(),
Paul's avatar
Paul committed
88
                                 cd.get(),
Paul's avatar
Paul committed
89
                                 algo,
Paul's avatar
Paul committed
90
                                 &beta,
Paul's avatar
Paul committed
91
                                 y_desc.get(),
Paul's avatar
Paul committed
92
                                 args[3].implicit(),
Paul's avatar
Paul committed
93
                                 args[2].implicit(),
Paul's avatar
Paul committed
94
95
                                 args[2].get_shape().bytes());
        return args[3];
Paul's avatar
Paul committed
96
    }
Paul's avatar
Paul committed
97

Paul's avatar
Paul committed
98
    shape compile(context& ctx, const shape& output_shape, std::vector<instruction_ref> inputs)
Paul's avatar
Paul committed
99
    {
Paul's avatar
Paul committed
100
        shape workspace_shape{};
Paul's avatar
Paul committed
101
102
103
104
        auto x_desc = make_tensor(inputs[0]->get_shape());
        auto w_desc = make_tensor(inputs[1]->get_shape());
        auto y_desc = make_tensor(output_shape);

Paul's avatar
Paul committed
105
        std::size_t workspace_size = 0;
Paul's avatar
Paul committed
106
        miopenConvolutionForwardGetWorkSpaceSize(
Paul's avatar
Paul committed
107
            ctx.handle.get(), w_desc.get(), x_desc.get(), cd.get(), y_desc.get(), &workspace_size);
Paul's avatar
Paul committed
108
109
        workspace_shape = shape{shape::int8_type, {workspace_size}};

Paul's avatar
Paul committed
110
111
112
        auto x         = to_gpu(generate_argument(inputs[0]->get_shape()));
        auto w         = to_gpu(generate_argument(inputs[1]->get_shape()));
        auto y         = to_gpu(generate_argument(output_shape));
Paul's avatar
Paul committed
113
        auto workspace = allocate_gpu(workspace_shape);
Paul's avatar
Paul committed
114

Paul's avatar
Paul committed
115
        int algo_count = 1;
Paul's avatar
Paul committed
116
117
118
119
120
121
122
123
124
125
126
127
        miopenConvAlgoPerf_t perf;
        miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
                                              x_desc.get(),
                                              x.implicit(),
                                              w_desc.get(),
                                              w.implicit(),
                                              cd.get(),
                                              y_desc.get(),
                                              y.implicit(),
                                              1,
                                              &algo_count,
                                              &perf,
Paul's avatar
Paul committed
128
129
                                              workspace.implicit(),
                                              workspace_size,
Paul's avatar
Paul committed
130
131
                                              false);
        algo = perf.fwd_algo;
Paul's avatar
Paul committed
132
133
        return algo == miopenConvolutionFwdAlgoWinograd ? shape{shape::int8_type, {0}}
                                                        : workspace_shape;
Paul's avatar
Paul committed
134
    }
Paul's avatar
Paul committed
135
136
};

Paul's avatar
Paul committed
137
138
139
140
141
struct miopen_pooling
{
    pooling op;
    shared<pooling_descriptor> pd;

Paul's avatar
Paul committed
142
    std::string name() const { return "gpu::pooling"; }
Paul's avatar
Paul committed
143
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
144
    {
Paul's avatar
Paul committed
145
        check_shapes{inputs, *this}.has(2).standard();
Paul's avatar
Paul committed
146
        return op.compute_shape({inputs.at(0)});
Paul's avatar
Paul committed
147
    }
Paul's avatar
Paul committed
148
149
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
150
    {
Paul's avatar
Paul committed
151
        auto x_desc = make_tensor(args[0].get_shape());
Paul's avatar
Paul committed
152
153
154
155
        auto y_desc = make_tensor(output_shape);

        float alpha = 1, beta = 0;

Paul's avatar
Paul committed
156
        miopenPoolingForward(ctx.handle.get(),
Paul's avatar
Paul committed
157
158
159
                             pd.get(),
                             &alpha,
                             x_desc.get(),
Paul's avatar
Paul committed
160
                             args[0].implicit(),
Paul's avatar
Paul committed
161
162
                             &beta,
                             y_desc.get(),
Paul's avatar
Paul committed
163
                             args[1].implicit(),
Paul's avatar
Paul committed
164
165
166
                             false,
                             nullptr,
                             0);
Paul's avatar
Paul committed
167

Paul's avatar
Paul committed
168
        return args[1];
Paul's avatar
Paul committed
169
170
171
    }
};

Paul's avatar
Paul committed
172
struct hip_add
Paul's avatar
Paul committed
173
{
Paul's avatar
Paul committed
174
    std::string name() const { return "gpu::add"; }
Paul's avatar
Paul committed
175
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
176
    {
Paul's avatar
Paul committed
177
        // check_shapes{inputs, *this}.has(3).standard();
Paul's avatar
Paul committed
178
        check_shapes{inputs, *this}.has(3);
Paul's avatar
Paul committed
179
        return inputs.at(0);
Paul's avatar
Paul committed
180
181
    }

Paul's avatar
Paul committed
182
    argument compute(context&, const shape&, const std::vector<argument>& args) const
Paul's avatar
Paul committed
183
    {
Paul's avatar
Paul committed
184
        device::add(args[2], args[0], args[1]);
Paul's avatar
Paul committed
185
        return args[2];
Paul's avatar
Paul committed
186
187
188
189
190
    }
};

struct miopen_add
{
Paul's avatar
Paul committed
191
    std::string name() const { return "gpu::add"; }
Paul's avatar
Paul committed
192
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
193
    {
Paul's avatar
Paul committed
194
        check_shapes{inputs, *this}.has(3).not_broadcasted();
Paul's avatar
Paul committed
195
        return inputs.at(0);
Paul's avatar
Paul committed
196
197
    }

Paul's avatar
Paul committed
198
199
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
200
    {
Paul's avatar
Paul committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
        float alpha = 1, beta = 0;
        auto a_desc = make_tensor(args[0].get_shape());
        auto b_desc = make_tensor(args[1].get_shape());
        auto c_desc = make_tensor(output_shape);
        miopenOpTensor(ctx.handle.get(),
                       miopenTensorOpAdd,
                       &alpha,
                       a_desc.get(),
                       args[0].implicit(),
                       &alpha,
                       b_desc.get(),
                       args[1].implicit(),
                       &beta,
                       c_desc.get(),
                       args[2].implicit());
        return args[2];
Paul's avatar
Paul committed
217
218
219
    }
};

Paul's avatar
Paul committed
220
221
222
struct miopen_gemm
{
    gemm op;
223
    std::string name() const { return "gpu::gemm"; }
Paul's avatar
Paul committed
224
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
225
    {
Paul's avatar
Paul committed
226
227
        check_shapes{inputs, *this}.has(3);
        return op.compute_shape({inputs.at(0), inputs.at(1)});
Paul's avatar
Paul committed
228
    }
Paul's avatar
Paul committed
229
230
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
231
    {
232
233
        float alpha     = 1.0f;
        float beta      = 0.0f;
Paul's avatar
Paul committed
234
235
        bool transa     = args[0].get_shape().transposed();
        bool transb     = args[1].get_shape().transposed();
236
237
238
        rocblas_int lda = args[0].get_shape().strides()[transa ? 1 : 0];
        rocblas_int ldb = args[1].get_shape().strides()[transb ? 1 : 0];
        rocblas_int ldc = args[2].get_shape().strides()[0];
239
240
241
        rocblas_int m   = output_shape.lens()[0];
        rocblas_int n   = output_shape.lens()[1];
        rocblas_int k   = args[0].get_shape().lens()[1];
242
        rocblas_sgemm(ctx.rbhandle.get(),
243
244
                      transb ? rocblas_operation_transpose : rocblas_operation_none,
                      transa ? rocblas_operation_transpose : rocblas_operation_none,
245
246
247
248
249
250
251
252
253
254
255
256
                      n,
                      m,
                      k,
                      &alpha,
                      args[1].implicit(),
                      ldb,
                      args[0].implicit(),
                      lda,
                      &beta,
                      args[2].implicit(),
                      ldc);
        return args[2];
Paul's avatar
Paul committed
257
258
259
    }
};

260
261
262
struct miopen_contiguous
{
    contiguous op;
Paul's avatar
Paul committed
263
    std::string name() const { return "gpu::contiguous"; }
Paul's avatar
Paul committed
264
    shape compute_shape(const std::vector<shape>& inputs) const
265
266
267
268
    {
        check_shapes{inputs, *this}.has(2);
        return op.compute_shape({inputs.at(0)});
    }
Paul's avatar
Paul committed
269
    argument compute(context&, shape output_shape, const std::vector<argument>& args) const
270
    {
Paul's avatar
Paul committed
271
272
        assert(output_shape == args[1].get_shape());
        assert(output_shape.standard());
Paul's avatar
Paul committed
273
        (void)output_shape;
274
        device::contiguous(args.at(1), args.at(0));
275
        return args.at(1);
276
277
278
    }
};

Paul's avatar
Paul committed
279
280
281
struct miopen_relu
{
    shared<activation_descriptor> ad;
Paul's avatar
Paul committed
282
    std::string name() const { return "gpu::relu"; }
Paul's avatar
Paul committed
283
    shape compute_shape(const std::vector<shape>& inputs) const
Paul's avatar
Paul committed
284
    {
Paul's avatar
Paul committed
285
        check_shapes{inputs, *this}.has(2).not_broadcasted();
Paul's avatar
Paul committed
286
        return inputs.at(1);
Paul's avatar
Paul committed
287
288
    }

Paul's avatar
Paul committed
289
290
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
Paul's avatar
Paul committed
291
292
    {
        float alpha = 1, beta = 0;
Paul's avatar
Paul committed
293
        auto x_desc = make_tensor(args[0].get_shape());
Paul's avatar
Paul committed
294
        auto y_desc = make_tensor(output_shape);
Paul's avatar
Paul committed
295
        miopenActivationForward(ctx.handle.get(),
Paul's avatar
Paul committed
296
297
298
                                ad.get(),
                                &alpha,
                                x_desc.get(),
Paul's avatar
Paul committed
299
                                args[0].implicit(),
Paul's avatar
Paul committed
300
301
                                &beta,
                                y_desc.get(),
Paul's avatar
Paul committed
302
                                args[1].implicit());
Paul's avatar
Paul committed
303

Paul's avatar
Paul committed
304
        return args[1];
Paul's avatar
Paul committed
305
306
307
    }
};

Paul's avatar
Paul committed
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
/// Lowered softmax executed with miopenSoftmaxForward.
///
/// Arguments: args[0] is the input, args[1] the preallocated output, which is
/// written and returned. `op` is carried along but not read by `compute`.
struct miopen_softmax
{
    softmax op;
    std::string name() const { return "gpu::softmax"; }

    /// Two standard-layout inputs; the result has the output buffer's shape.
    shape compute_shape(const std::vector<shape>& inputs) const
    {
        check_shapes{inputs, *this}.has(2).standard();
        const shape& output = inputs.at(1);
        return output;
    }

    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
    {
        const argument& src = args[0];
        const argument& dst = args[1];

        auto src_desc = make_tensor(src.get_shape());
        auto dst_desc = make_tensor(output_shape);

        float alpha = 1;
        float beta  = 0;

        // NOTE(review): the returned miopenStatus_t is not checked.
        miopenSoftmaxForward(ctx.handle.get(),
                             &alpha,
                             src_desc.get(),
                             src.implicit(),
                             &beta,
                             dst_desc.get(),
                             dst.implicit());
        return dst;
    }
};

Paul's avatar
Paul committed
336
337
struct miopen_apply
{
Paul's avatar
Paul committed
338
    program* prog = nullptr;
Paul's avatar
Paul committed
339
    context ctx{};
Paul's avatar
Paul committed
340

Paul's avatar
Paul committed
341
342
343
344
345
346
347
    void check_shape(shape x, instruction_ref i)
    {
        assert(x == i->get_shape());
        (void)x;
        (void)i;
    }

Paul's avatar
Paul committed
348
349
    void apply()
    {
Paul's avatar
Paul committed
350
351
        for(auto it = prog->begin(); it != prog->end(); it++)
        {
Paul's avatar
Paul committed
352
            auto s = it->get_shape();
Paul's avatar
Paul committed
353
            if(it->name() == "convolution")
Paul's avatar
Paul committed
354
            {
Paul's avatar
Paul committed
355
                check_shape(s, apply_convolution(it));
Paul's avatar
Paul committed
356
            }
Paul's avatar
Paul committed
357
            else if(it->name() == "activation")
Paul's avatar
Paul committed
358
            {
Paul's avatar
Paul committed
359
                check_shape(s, apply_activation(it));
Paul's avatar
Paul committed
360
            }
Paul's avatar
Paul committed
361
            else if(it->name() == "pooling")
Paul's avatar
Paul committed
362
            {
Paul's avatar
Paul committed
363
                check_shape(s, apply_pooling(it));
Paul's avatar
Paul committed
364
            }
Paul's avatar
Paul committed
365
            else if(it->name() == "add")
Paul's avatar
Paul committed
366
            {
Paul's avatar
Paul committed
367
                check_shape(s, apply_add(it));
Paul's avatar
Paul committed
368
            }
Paul's avatar
Paul committed
369
            else if(it->name() == "gemm")
Paul's avatar
Paul committed
370
            {
Paul's avatar
Paul committed
371
                check_shape(s, apply_gemm(it));
Paul's avatar
Paul committed
372
            }
Paul's avatar
Paul committed
373
            else if(it->name() == "contiguous")
374
            {
Paul's avatar
Paul committed
375
                check_shape(s, apply_contiguous(it));
376
            }
Paul's avatar
Paul committed
377
            else if(it->name() == "batch_norm_inference")
378
            {
Paul's avatar
Paul committed
379
                check_shape(s, apply_batch_norm_inference(it));
380
            }
Paul's avatar
Paul committed
381
382
383
384
            else if(it->name() == "softmax")
            {
                check_shape(s, apply_softmax(it));
            }
Paul's avatar
Paul committed
385
386
387
        }
    }

Paul's avatar
Paul committed
388
    instruction_ref insert_allocation(instruction_ref ins, const shape& s, std::string tag = "")
Paul's avatar
Paul committed
389
    {
Paul's avatar
Paul committed
390
        if(ins == --prog->end() and tag.empty())
Paul's avatar
Paul committed
391
392
393
394
395
        {
            return prog->add_parameter("output", s);
        }
        else
        {
Paul's avatar
Paul committed
396
            auto is     = prog->add_outline(s);
Paul's avatar
Paul committed
397
            auto result = prog->insert_instruction(ins, hip_allocate{std::move(tag)}, is);
Paul's avatar
Paul committed
398
399
400
401
            return result;
        }
    }

Paul's avatar
Paul committed
402
    instruction_ref apply_convolution(instruction_ref ins)
Paul's avatar
Paul committed
403
    {
404
        auto&& op = any_cast<convolution>(ins->get_operator());
Paul's avatar
Paul committed
405

Paul's avatar
Paul committed
406
        auto conv = miopen_convolution{op, make_conv(op)};
Paul's avatar
Paul committed
407
        auto ws   = conv.compile(ctx, ins->get_shape(), ins->inputs());
Paul's avatar
Paul committed
408

409
        auto workspace = insert_allocation(ins, ws, "workspace");
Paul's avatar
Paul committed
410
        auto output    = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
411

Paul's avatar
Paul committed
412
        return prog->replace_instruction(
Paul's avatar
Paul committed
413
            ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
Paul's avatar
Paul committed
414
415
    }

Paul's avatar
Paul committed
416
    instruction_ref apply_pooling(instruction_ref ins)
Paul's avatar
Paul committed
417
    {
418
        auto&& op   = any_cast<pooling>(ins->get_operator());
Paul's avatar
Paul committed
419
        auto pd     = make_pooling(op);
Paul's avatar
Paul committed
420
        auto output = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
421

Paul's avatar
Paul committed
422
        return prog->replace_instruction(
Paul's avatar
Paul committed
423
            ins, miopen_pooling{op, std::move(pd)}, ins->inputs().at(0), output);
Paul's avatar
Paul committed
424
425
    }

Paul's avatar
Paul committed
426
    instruction_ref apply_activation(instruction_ref ins)
Paul's avatar
Paul committed
427
    {
428
        auto&& op = any_cast<activation>(ins->get_operator());
Paul's avatar
Paul committed
429
430
        auto ad   = make_relu();
        if(op.mode == "relu")
Paul's avatar
Paul committed
431
        {
Paul's avatar
Paul committed
432
            auto output = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
433
            return prog->replace_instruction(
Paul's avatar
Paul committed
434
                ins, miopen_relu{std::move(ad)}, ins->inputs().at(0), output);
Paul's avatar
Paul committed
435
        }
Paul's avatar
Paul committed
436
        return ins;
Paul's avatar
Paul committed
437
    }
Paul's avatar
Paul committed
438

Paul's avatar
Paul committed
439
440
441
442
443
444
445
    instruction_ref apply_softmax(instruction_ref ins)
    {
        auto output = insert_allocation(ins, ins->get_shape());
            return prog->replace_instruction(
                ins, miopen_softmax{}, ins->inputs().at(0), output);
    }

Paul's avatar
Paul committed
446
    instruction_ref apply_add(instruction_ref ins)
Paul's avatar
Paul committed
447
    {
Paul's avatar
Paul committed
448
        auto output = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
449
        return prog->replace_instruction(
Paul's avatar
Paul committed
450
            ins, hip_add{}, ins->inputs().at(0), ins->inputs().at(1), output);
Paul's avatar
Paul committed
451
    }
Paul's avatar
Paul committed
452

Paul's avatar
Paul committed
453
    instruction_ref apply_gemm(instruction_ref ins)
Paul's avatar
Paul committed
454
    {
455
        auto&& op   = any_cast<gemm>(ins->get_operator());
Paul's avatar
Paul committed
456
        auto output = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
457
        return prog->replace_instruction(
Paul's avatar
Paul committed
458
            ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
Paul's avatar
Paul committed
459
    }
460

Paul's avatar
Paul committed
461
    instruction_ref apply_contiguous(instruction_ref ins)
462
    {
463
        auto&& op   = any_cast<contiguous>(ins->get_operator());
Paul's avatar
Paul committed
464
        auto output = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
465
        return prog->replace_instruction(ins, miopen_contiguous{op}, ins->inputs().at(0), output);
466
    }
467

Paul's avatar
Paul committed
468
    instruction_ref apply_batch_norm_inference(instruction_ref ins)
469
    {
470
        auto&& op       = any_cast<batch_norm_inference>(ins->get_operator());
Paul's avatar
Paul committed
471
        auto output     = insert_allocation(ins, ins->get_shape());
Paul's avatar
Paul committed
472
        shape old_shape = ins->inputs().at(1)->get_shape();
wsttiger's avatar
wsttiger committed
473
        std::vector<int64_t> new_shape{1, static_cast<int64_t>(old_shape.elements()), 1, 1};
Paul's avatar
Paul committed
474
475
        auto reshape_op = reshape{new_shape};
        std::vector<instruction_ref> reshapes;
Paul's avatar
Paul committed
476
477
        std::transform(ins->inputs().begin() + 1,
                       ins->inputs().end(),
Paul's avatar
Paul committed
478
479
                       std::back_inserter(reshapes),
                       [&](auto i) { return prog->insert_instruction(ins, reshape_op, i); });
Paul's avatar
Paul committed
480
        return prog->replace_instruction(ins,
Paul's avatar
Paul committed
481
                                         miopen_batch_norm_inference{op},
Paul's avatar
Paul committed
482
                                         ins->inputs().at(0),
Paul's avatar
Paul committed
483
484
485
486
487
                                         reshapes[0],
                                         reshapes[1],
                                         reshapes[2],
                                         reshapes[3],
                                         output);
488
    }
Paul's avatar
Paul committed
489
490
};

Paul's avatar
Paul committed
491
// Runs the GPU lowering pass over `p`, giving the pass its own copy of this
// object's context.
void lowering::apply(program& p) const
{
    miopen_apply pass{&p, ctx};
    pass.apply();
}
Paul's avatar
Paul committed
492

Paul's avatar
Paul committed
493
} // namespace gpu
Paul's avatar
Paul committed
494

Paul's avatar
Paul committed
495
} // namespace migraph