"test/verify/test_quantizelinear.cpp" did not exist on "4983fecd4178da7e84478b6e8b035554baa0f3c0"
onnx.cpp 35.6 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <algorithm>
#include <iterator>
#include <numeric>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

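// Parses an ONNX protobuf model into a migraphx program: graph inputs and
// initializers become parameters and literals, and each node is translated by
// the operator handler registered under its op type in the `ops` table; nodes
// with no registered handler are inserted as placeholder instructions.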
struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog    = program();
    bool is_pytorch = false;

    std::unordered_map<std::string, op_func> ops;

    onnx_parser()
    {
        add_generic_op("MatMul", op::dot{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
    }

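    // Register a handler that returns a single instruction; the result is
    // wrapped into the std::vector<instruction_ref> expected by the ops table.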
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi-output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

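    // Register a parser member function (e.g. &onnx_parser::parse_conv),
    // binding it to this instance and forwarding the op name to it.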
    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

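    // Register an elementwise binary op. The legacy ONNX "broadcast"/"axis"
    // attributes are honored when present; otherwise the two arguments are
    // multibroadcast to a common shape.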
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l =
                        prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

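    // Apply a binary op after broadcasting both arguments to a common shape,
    // matching dimensions from the right (NumPy-style rules).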
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape() != arg1->get_shape())
        {
            // Example:
            // s0 = (3,2,4,5) and s1 = (2,1,1)
            //
            // In this case we need to broadcast (:,1,1) portion of
            // s1 plus broadcast the 1st dimension of s1
            // giving output_lens = (3,2,4,5)
            //
            // Another example:
            // s0 = (3,2,1,5) and s1 = (2,7,5)
            // In this case we need to broadcast the (:,:,1:,:) axis
            // of s0 plus the 1st dimension of s1 giving
            // output_lens = (3,2,7,5)
            //
            // Get lengths for both arguments
            const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
            const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();

            // Make sure s0 is the smaller size
            if(s0->size() > s1->size())
                std::swap(s0, s1);

            std::vector<std::size_t> output_lens(*s1);
            auto offset = s1->size() - s0->size();
            std::transform(s0->begin(),
                           s0->end(),
                           s1->begin() + offset,
                           output_lens.begin() + offset,
                           [](auto a, auto b) { return std::max(a, b); });

            auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
            auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

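    // Register an op whose inputs map directly onto a single migraphx
    // instruction with no attribute handling.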
    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

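    // Register a variadic op (Sum/Max/Min) as a left fold of the binary op
    // over all arguments, broadcasting at each step.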
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

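    // migraphx softmax expects a 4D input, so reshape the 2D input to
    // (N, C, 1, 1), apply softmax, and reshape back to (N, C).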
    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto dims = args.front()->get_shape().lens();
        auto r =
            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
        auto s = prog.add_instruction(op::softmax{}, r);
        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
    }

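    // Translate Conv: handles pads/strides/dilations/auto_pad/group, and adds
    // a broadcast bias when a third argument is supplied.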
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        if(contains(attributes, "pads"))
        {
            if(contains(attributes, "auto_pad"))
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }
            std::vector<std::size_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                MIGRAPHX_THROW("migraphx does not support asymetric padding");
            }
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::convolution::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        if(args.size() == 3)
        {
            uint64_t axis = 1;
            auto l1       = prog.add_instruction(op, args[0], args[1]);
            auto l2       = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, args);
    }

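    // Shared handler for MaxPool/AveragePool and their Global variants; the
    // Global forms use the full spatial dimensions as the kernel shape.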
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }
        if(contains(attributes, "pads"))
        {
            std::vector<std::size_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                MIGRAPHX_THROW("migraphx does not support asymetric padding");
            }
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad is not supported for pooling");
            }
        }

        return prog.add_instruction(op, std::move(args));
    }

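    // The target dims come either from the "shape" attribute or, with two
    // arguments, from a second input read as a literal.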
    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            literal s = args[1]->get_literal();
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        return prog.add_instruction(op, args[0]);
    }

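    // Flatten the input into a 2D matrix around the given axis (default 1).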
    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        uint64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::size_t axis = parse_value(attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::size_t axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        op::gather op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        if(contains(attributes, "axes"))
        {
            literal s = parse_value(attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v = parse_value(attributes.at("value"));
        return prog.add_literal(v);
    }

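    // Gemm computes alpha * op(A) * op(B) + beta * C, where op() is an
    // optional transpose controlled by the transA/transB attributes.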
    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        if(contains(attributes, "beta"))
        {
            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
            transa = parse_value(attributes.at("transA")).at<bool>();
        }
        if(contains(attributes, "transB"))
        {
            transb = parse_value(attributes.at("transB")).at<bool>();
        }
        std::vector<int64_t> perm = {1, 0};
        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f)
            {
                auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
                auto l4 = args[2];
                if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
                    return l3;
                if(beta != 1.f)
                {
                    auto beta_val = prog.add_literal(beta);
                    auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
                    l4      = prog.add_instruction(op::mul{}, args[2], l5);
                }
                return add_broadcastable_binary_op(l3, l4, op::add{});
            }
        }
        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

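    // BatchNormalization is parsed for inference only; is_test is read but
    // otherwise ignored.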
    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        bool is_test                                      = false;
        if(contains(attributes, "epsilon"))
        {
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        }
        if(contains(attributes, "momentum"))
        {
            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "is_test"))
        {
            is_test = parse_value(attributes.at("is_test")).at<uint64_t>() > 0;
        }
        if(contains(attributes, "spatial"))
        {
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        (void)is_test;
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(attributes, "scale"))
        {
            scale = parse_value(attributes.at("scale")).at<float>();
        }

        if(contains(attributes, "bias"))
        {
            auto&& bias_floats = attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape = args.front()->get_shape();

        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    // Use a literal instruction to replace the Shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the ConstantFill operator. In RNNs,
    // the input shape and value are fixed, so there is no need to do the actual
    // computation for the ConstantFill operator
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
        {
            dtype = parse_value(attributes.at("dtype")).at<int>();
        }
        migraphx::shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
        }

        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }

        if(contains(attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            if(in.empty())
            {
                MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
            }

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        std::unordered_map<std::string, onnx::TensorProto> initializer_data;
        for(auto&& f : graph.initializer())
        {
            initializer_data[f.name()] = f;
        }
        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // Does the input have an initializer?
            if(contains(initializer_data, name))
            {
                auto t             = initializer_data[name];
                instructions[name] = prog.add_literal(parse_tensor(t));
            }
            else
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type());
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& p : nodes)
        {
            this->parse_node(p.first);
        }
    }

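    // Parse a node after recursively parsing the nodes that produce its
    // inputs; finished nodes are memoized in the `instructions` map.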
    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                    args.push_back(instructions.at(input));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()](get_attributes(node), args);
            }
            // Even nodes with no declared outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                assert(node.output().size() >= result.size());
                std::transform(result.begin(),
                               result.end(),
                               node.output().begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
            }
        }
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
Paul's avatar
Paul committed
753
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::UNDEFINED: return {};
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::STRING: return {};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::GRAPH: return {};
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::STRINGS: return {};
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        // In case of scalar constants in the onnx file, use dims = {1} to fill the initializer data
        if(dims.empty())
        {
            dims = {1};
        }
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
            case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
            case onnx::TensorProto::UINT8: throw std::runtime_error("");
            case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
            case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
            case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
            case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
            case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
            case onnx::TensorProto::STRING: throw std::runtime_error("");
            case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
            case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
            case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
            case onnx::TensorProto::UINT32: throw std::runtime_error("");
            case onnx::TensorProto::UINT64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
        case onnx::TensorProto::FLOAT:
            return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
837
        case onnx::TensorProto::UINT8: throw std::runtime_error("");
        case onnx::TensorProto::INT8:
            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::UINT16:
            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT16:
            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT32:
            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT64:
            return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
        case onnx::TensorProto::STRING: throw std::runtime_error("");
        case onnx::TensorProto::BOOL:
            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::FLOAT16:
            return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
        case onnx::TensorProto::DOUBLE:
            return literal{
                {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
        case onnx::TensorProto::UINT32: throw std::runtime_error("");
        case onnx::TensorProto::UINT64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static shape parse_type(const onnx::TypeProto& t)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::UNDEFINED:
            break; // throw std::runtime_error("Unsupported type UNDEFINED");
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::UINT8:
            break; // throw std::runtime_error("Unsupported type UINT8");
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::STRING:
            break; // throw std::runtime_error("Unsupported type STRING");
        case onnx::TensorProto::BOOL:
            break; // throw std::runtime_error("Unsupported type BOOL");
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::COMPLEX64:
            break; // throw std::runtime_error("Unsupported type COMPLEX64");
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type COMPLEX128");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [](auto&& d) -> std::size_t {
                           if(not d.has_dim_value())
                           {
                               long default_batch_size = 1; // FIXME
                               return default_batch_size;
                           }
                           return d.dim_value();
                       });
        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
};

program parse_onnx(const std::string& name)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    return std::move(parser.prog);
}
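
// Example usage (a minimal sketch; the model path is hypothetical):
//
//   migraphx::program p = migraphx::parse_onnx("resnet50.onnx");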

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx