onnx.cpp 80.5 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
Paul's avatar
Paul committed
9
#include <utility>
10
#include <vector>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
13
14
15
16
17
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
18
#include <migraphx/onnx.hpp>
19
#include <migraphx/pad_calc.hpp>
Paul's avatar
Paul committed
20
21

namespace migraphx {
Paul's avatar
Paul committed
22
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
23
24
25
26

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
27
28
29
30
31
32
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
Paul's avatar
Paul committed
33
    using op_func =
34
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
Paul's avatar
Paul committed
35
36
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
37
38
39
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;
Paul's avatar
Paul committed
40
41

    std::unordered_map<std::string, op_func> ops;
42
    std::unordered_map<std::string, operation> map_actv_funcs;
Paul's avatar
Paul committed
43
44
45

    onnx_parser()
    {
46
        // sort onnx operator alphabetically through name
Khalique's avatar
Khalique committed
47
        add_generic_op("Abs", op::abs{});
48
49
50
51
52
53
54
55
56
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
Shucai Xiao's avatar
Shucai Xiao committed
57
        add_generic_op("Erf", op::erf{});
58
        add_generic_op("Exp", op::exp{});
Khalique's avatar
Khalique committed
59
        add_generic_op("Dropout", op::identity{});
60
61
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
Khalique's avatar
Khalique committed
62
        add_generic_op("Identity", op::identity{});
63
64
65
66
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
Shucai Xiao's avatar
Shucai Xiao committed
67
        add_generic_op("Sin", op::sin{});
68
        add_generic_op("Sinh", op::sinh{});
69
        add_generic_op("Sqrt", op::sqrt{});
70
71
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});
Paul's avatar
Paul committed
72

Khalique's avatar
Khalique committed
73
74
75
        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
Shucai Xiao's avatar
Shucai Xiao committed
76
        add_binary_op("Pow", op::pow{});
Shucai Xiao's avatar
Shucai Xiao committed
77
        add_binary_op("PRelu", op::prelu{});
78
        add_binary_op("Sub", op::sub{});
Khalique's avatar
Khalique committed
79

Khalique's avatar
Khalique committed
80
81
82
        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});
Paul's avatar
Paul committed
83

84
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
85
86
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
87
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
88
        add_mem_op("Cast", &onnx_parser::parse_cast);
Khalique's avatar
Khalique committed
89
        add_mem_op("Clip", &onnx_parser::parse_clip);
90
        add_mem_op("Concat", &onnx_parser::parse_concat);
Paul's avatar
Paul committed
91
        add_mem_op("Constant", &onnx_parser::parse_constant);
92
93
94
95
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
kahmed10's avatar
kahmed10 committed
96
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
97
98
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
Paul's avatar
Paul committed
99
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
100
        add_mem_op("Gather", &onnx_parser::parse_gather);
Paul's avatar
Paul committed
101
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
102
103
104
105
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
kahmed10's avatar
kahmed10 committed
106
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
107
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
108
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
109
110
111
112
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
Shucai Xiao's avatar
Shucai Xiao committed
113
114
115
116
117
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
Shucai Xiao's avatar
Shucai Xiao committed
118
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
Shucai Xiao's avatar
Shucai Xiao committed
119
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
Shucai Xiao's avatar
Shucai Xiao committed
120
121
122
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
123
124
125
126
127
128
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
129
        add_mem_op("Split", &onnx_parser::parse_split);
130
131
132
133
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
134
135
136
137
138
139
140

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
141
142
143
144
145
146
        // Support name format of all lower case or the first letter capital
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
Paul's avatar
Paul committed
147
148
149
150
    }

    template <class F>
    void add_op(std::string name, F f)
Paul's avatar
Paul committed
151
152
153
154
155
156
157
158
159
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
Paul's avatar
Paul committed
160
161
162
163
164
165
166
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
Paul's avatar
Paul committed
167
        add_op(name, [=](auto&&... xs) {
Paul's avatar
Paul committed
168
169
170
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }
Khalique's avatar
Khalique committed
171

172
    template <class T>
Khalique's avatar
Khalique committed
173
    void add_binary_op(std::string name, T x)
174
    {
175
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
Scott Thornton's avatar
Scott Thornton committed
176
            if(args.size() != 2)
Paul's avatar
Paul committed
177
                MIGRAPHX_THROW("binary operators should have 2 operands");
178
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
179
            {
180
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
181
182
                if(broadcasted != 0)
                {
183
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
Shucai Xiao's avatar
Shucai Xiao committed
184
185
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
186
187
                    return prog.add_instruction(x, args[0], l);
                }
188
                return prog.add_instruction(x, args);
189
            }
Paul's avatar
Paul committed
190
            else
191
            {
Khalique's avatar
Khalique committed
192
                return add_broadcastable_binary_op(args[0], args[1], x);
193
194
195
196
            }
        });
    }

Shucai Xiao's avatar
Shucai Xiao committed
197
198
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
199
200
201
202
203
204
205
206
207
208
209
210
211
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
Shucai Xiao's avatar
Shucai Xiao committed
212
        if(s0.size() > s1.size())
213
214
215
216
217
218
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
Shucai Xiao's avatar
Shucai Xiao committed
219
220
221
222
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
223
                       [&](auto a, auto b) {
Shucai Xiao's avatar
Shucai Xiao committed
224
                           if(a != b and a != 1 and b != 1)
225
                           {
Shucai Xiao's avatar
Shucai Xiao committed
226
227
228
229
230
231
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });
232
233
234
235

        return out_lens;
    }

Shucai Xiao's avatar
Shucai Xiao committed
236
237
    instruction_ref make_contiguous(instruction_ref ins)
    {
Shucai Xiao's avatar
Shucai Xiao committed
238
        if(ins->get_shape().standard())
Shucai Xiao's avatar
Shucai Xiao committed
239
240
241
242
243
244
245
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

Khalique's avatar
Khalique committed
246
247
248
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
Khalique's avatar
Khalique committed
249
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
Khalique's avatar
Khalique committed
250
251
        {
            // Get lengths for both arguments
Shucai Xiao's avatar
Shucai Xiao committed
252
253
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
254
            auto out_lens = compute_broadcasted_lens(s0, s1);
255
256
257
258
259
260
261
262
263

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

Khalique's avatar
Khalique committed
264
265
266
267
268
269
            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
270
271
    }

Paul's avatar
Paul committed
272
    template <class T>
Paul's avatar
Paul committed
273
274
    void add_generic_op(std::string name, T x)
    {
275
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
Paul's avatar
Paul committed
276
277
278
279
            return prog.add_instruction(x, args);
        });
    }

Khalique's avatar
Khalique committed
280
    template <class T>
Khalique's avatar
Khalique committed
281
    void add_variadic_op(std::string name, T x)
Khalique's avatar
Khalique committed
282
    {
283
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
Khalique's avatar
Khalique committed
284
            return std::accumulate(std::next(args.begin()),
Khalique's avatar
Khalique committed
285
286
287
288
289
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
Khalique's avatar
Khalique committed
290
        });
Khalique's avatar
Khalique committed
291
292
    }

kahmed10's avatar
kahmed10 committed
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

312
313
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
314
                            const std::vector<int64_t>& padding,
315
316
317
318
319
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
320
321
322
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
323
324
325
326
327
328
329
330
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

331
332
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
333
    {
kahmed10's avatar
kahmed10 committed
334
335
336
337
338
339
340
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
Khalique's avatar
Khalique committed
341
        {
kahmed10's avatar
kahmed10 committed
342
343
344
345
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
Khalique's avatar
Khalique committed
346
        }
kahmed10's avatar
kahmed10 committed
347
        else if(args.size() == 2)
Khalique's avatar
Khalique committed
348
        {
kahmed10's avatar
kahmed10 committed
349
350
351
352
353
354
355
356
357
358
359
360
361
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
Khalique's avatar
Khalique committed
362
        }
kahmed10's avatar
kahmed10 committed
363
364
365
366
367
368
369
370
371
372
373
374
375

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
Khalique's avatar
Khalique committed
376
377
    }

Shucai Xiao's avatar
Shucai Xiao committed
378
    template <class Op>
379
380
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
381
    {
382
        int64_t axis = 1;
383
        if(contains(info.attributes, "axis"))
384
        {
385
            axis = parse_value(info.attributes.at("axis")).at<int>();
386
387
        }

388
        return prog.add_instruction(Op{axis}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
389
390
    }

Shucai Xiao's avatar
Shucai Xiao committed
391
    template <class Op>
392
393
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
394
    {
395
        int64_t axis = 0;
396
        if(contains(info.attributes, "axis"))
397
        {
398
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
399
400
        }

Shucai Xiao's avatar
Shucai Xiao committed
401
        int keep_dims = 1;
402
        if(contains(info.attributes, "keepdims"))
Shucai Xiao's avatar
Shucai Xiao committed
403
        {
404
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
405
406
        }

Shucai Xiao's avatar
Shucai Xiao committed
407
        if(keep_dims == 0)
408
        {
409
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
410
            return prog.add_instruction(op::squeeze{{axis}}, ins);
411
412
413
        }
        else
        {
414
            return prog.add_instruction(Op{axis}, std::move(args));
415
        }
416
417
    }

418
419
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
420
                                               node_info info,
421
                                               Op& op,
422
423
424
425
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
426
    {
427
        if(!contains(info.attributes, "auto_pad"))
428
429
430
431
        {
            return ins;
        }

432
        auto auto_pad = info.attributes["auto_pad"].s();
433
434
        if(auto_pad.find("SAME") != std::string::npos)
        {
435
436
437
438
439
440
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);
441

442
            check_asym_padding(ins, padding, op, value);
443
444
445
446
447
        }

        return ins;
    }

448
    template <class Op>
Paul's avatar
Paul committed
449
    instruction_ref
450
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
451
    {
452
        Op op;
453
454
        auto l0      = args[0];
        auto weights = args[1];
455
        std::vector<int64_t> padding;
456
        if(contains(info.attributes, "pads"))
Paul's avatar
Paul committed
457
        {
458
            if(contains(info.attributes, "auto_pad"))
459
            {
460
461
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
462
                {
463
464
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
465
                }
466
            }
467
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
Scott Thornton's avatar
Scott Thornton committed
468
            if(padding.size() != 4)
469
            {
470
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
471
            }
472
            check_asym_padding(l0, padding, op);
Paul's avatar
Paul committed
473
        }
474
        if(contains(info.attributes, "strides"))
Paul's avatar
Paul committed
475
        {
476
            copy(info.attributes["strides"].ints(), op.stride.begin());
Paul's avatar
Paul committed
477
        }
478
        if(contains(info.attributes, "dilations"))
Paul's avatar
Paul committed
479
        {
480
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
Paul's avatar
Paul committed
481
        }
482
        if(contains(info.attributes, "auto_pad"))
483
        {
484
            auto s = info.attributes["auto_pad"].s();
wsttiger's avatar
fixes  
wsttiger committed
485
            if(s.find("SAME") != std::string::npos)
486
            {
487
488
489
490
491
492
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
493
                padding.resize(input_dims.size());
494
495
496
497
498
499
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
500
            }
501
502
503
504
505

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
506
        }
507
        if(contains(info.attributes, "group"))
Khalique's avatar
Khalique committed
508
        {
509
            op.group = parse_value(info.attributes.at("group")).at<int>();
Khalique's avatar
Khalique committed
510
        }
kahmed10's avatar
kahmed10 committed
511
512
513
514
515

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

516
517
    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
kahmed10's avatar
kahmed10 committed
518
519
520
521
522
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
523
        if(contains(info.attributes, "pads"))
kahmed10's avatar
kahmed10 committed
524
        {
525
            if(contains(info.attributes, "auto_pad"))
kahmed10's avatar
kahmed10 committed
526
            {
527
528
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
kahmed10's avatar
kahmed10 committed
529
530
531
532
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
533
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
kahmed10's avatar
kahmed10 committed
534
535
536
537
538
539
540
541
542
543
544
545
546
547
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
548
        if(contains(info.attributes, "strides"))
kahmed10's avatar
kahmed10 committed
549
        {
550
            copy(info.attributes["strides"].ints(), op.stride.begin());
kahmed10's avatar
kahmed10 committed
551
        }
552
        if(contains(info.attributes, "dilations"))
Paul's avatar
Paul committed
553
        {
554
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
Paul's avatar
Paul committed
555
        }
556
        if(contains(info.attributes, "auto_pad"))
kahmed10's avatar
kahmed10 committed
557
        {
558
559
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
kahmed10's avatar
kahmed10 committed
560
561
562
563
564
565
566
567
568
569
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

570
        if(contains(info.attributes, "group"))
kahmed10's avatar
kahmed10 committed
571
        {
572
            op.group = parse_value(info.attributes.at("group")).at<int>();
kahmed10's avatar
kahmed10 committed
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

589
        if(contains(info.attributes, "output_padding"))
kahmed10's avatar
kahmed10 committed
590
591
        {
            std::vector<int64_t> output_padding;
592
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
kahmed10's avatar
kahmed10 committed
593
594
595
596
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

597
        if(contains(info.attributes, "output_shape"))
kahmed10's avatar
kahmed10 committed
598
599
        {
            std::vector<int64_t> output_shape;
600
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
kahmed10's avatar
kahmed10 committed
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
Paul's avatar
Paul committed
618
    }
Paul's avatar
Paul committed
619

620
621
    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
622
    {
Khalique's avatar
Khalique committed
623
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
624
        auto l0 = args[0];
Khalique's avatar
Khalique committed
625
        if(starts_with(name, "Global"))
626
        {
Khalique's avatar
Khalique committed
627
628
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
629
        }
630

631
        if(contains(info.attributes, "pads"))
Paul's avatar
Paul committed
632
        {
633
            if(contains(info.attributes, "auto_pad"))
634
            {
635
                auto s = info.attributes["auto_pad"].s();
636
637
638
639
640
641
642
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

643
            std::vector<std::int64_t> padding;
644
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
Scott Thornton's avatar
Scott Thornton committed
645
            if(padding.size() != 4)
646
            {
647
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
648
            }
649
650
651
652
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
Paul's avatar
Paul committed
653
        }
654

655
        if(contains(info.attributes, "strides"))
Paul's avatar
Paul committed
656
        {
657
            copy(info.attributes["strides"].ints(), op.stride.begin());
Paul's avatar
Paul committed
658
        }
659
        if(contains(info.attributes, "kernel_shape"))
Paul's avatar
Paul committed
660
        {
661
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
Paul's avatar
Paul committed
662
        }
663

664
        if(contains(info.attributes, "auto_pad"))
665
        {
666
667
668
669
670
671
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

672
            auto in_lens = args[0]->get_shape().lens();
673
674
675
676
677
678
679
680
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
681
682
        }

683
        return prog.add_instruction(op, l0);
Paul's avatar
Paul committed
684
685
    }

Paul's avatar
Paul committed
686
    instruction_ref
687
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
688
    {
689
        op::reshape op;
Paul's avatar
Paul committed
690
691
        if(args.size() == 1)
        {
692
            literal s = parse_value(info.attributes.at("shape"));
693
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
Paul's avatar
Paul committed
694
695
696
        }
        if(args.size() == 2)
        {
Paul's avatar
Paul committed
697
            auto s = args[1]->eval();
Shucai Xiao's avatar
Shucai Xiao committed
698
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
Paul's avatar
Paul committed
699
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
Paul's avatar
Paul committed
700
        }
701

Shucai Xiao's avatar
Shucai Xiao committed
702
        return prog.add_instruction(op, make_contiguous(args[0]));
Paul's avatar
Paul committed
703
704
    }

Paul's avatar
Paul committed
705
    instruction_ref
706
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
707
    {
708
        int64_t axis = 1;
709
        if(contains(info.attributes, "axis"))
Paul's avatar
Paul committed
710
        {
711
            axis = parse_value(info.attributes.at("axis")).at<int>();
Paul's avatar
Paul committed
712
        }
713
        return prog.add_instruction(op::flatten{axis}, args[0]);
Paul's avatar
Paul committed
714
715
    }

716
    instruction_ref
717
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
718
719
    {
        op::squeeze op;
720
        literal s = parse_value(info.attributes.at("axes"));
721
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
722
        return prog.add_instruction(op, make_contiguous(args[0]));
723
724
725
    }

    instruction_ref
726
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
727
728
    {
        op::unsqueeze op;
729
        literal s = parse_value(info.attributes.at("axes"));
730
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
731
        return prog.add_instruction(op, make_contiguous(args[0]));
732
733
    }

Scott Thornton's avatar
Scott Thornton committed
734
    instruction_ref
735
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
Scott Thornton's avatar
Scott Thornton committed
736
    {
Shucai Xiao's avatar
Shucai Xiao committed
737
        // change to hande axis to be negative values
738
        if(!contains(info.attributes, "axis"))
Shucai Xiao's avatar
Shucai Xiao committed
739
740
741
742
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

743
        int axis = parse_value(info.attributes.at("axis")).at<int>();
Scott Thornton's avatar
Scott Thornton committed
744
745
746
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }
747

748
    instruction_ref
749
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
750
    {
751
        int axis = 0;
752
        if(contains(info.attributes, "axis"))
753
        {
754
            axis = parse_value(info.attributes.at("axis")).at<int>();
755
        }
756

757
        op::gather op{axis};
Shucai Xiao's avatar
Shucai Xiao committed
758
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
759
760
    }

761
    instruction_ref
762
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
763
764
    {
        op::slice op;
Shucai Xiao's avatar
Shucai Xiao committed
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
787
        {
788
            literal s = parse_value(info.attributes.at("axes"));
789
790
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
Shucai Xiao's avatar
Shucai Xiao committed
791
792

        if(args.size() >= 3)
Khalique's avatar
Khalique committed
793
        {
Shucai Xiao's avatar
Shucai Xiao committed
794
795
796
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
Khalique's avatar
Khalique committed
797
        }
Shucai Xiao's avatar
Shucai Xiao committed
798
        else if(contains(info.attributes, "ends"))
799
        {
800
801
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
802
        }
Shucai Xiao's avatar
Shucai Xiao committed
803
804
805
806
807
808
809
810

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
811
        {
812
            literal s = parse_value(info.attributes.at("starts"));
813
814
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
Shucai Xiao's avatar
Shucai Xiao committed
815

816
817
818
        return prog.add_instruction(op, args[0]);
    }

819
820
    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
Paul's avatar
Paul committed
821
    {
822
        literal v = parse_value(info.attributes.at("value"));
823
        // return empty literal
Shucai Xiao's avatar
Shucai Xiao committed
824
        if(v.get_shape().elements() == 0)
825
826
827
828
        {
            return prog.add_literal(literal{});
        }

829
        auto dim_size = info.attributes.at("value").t().dims_size();
830
831
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
832
        {
833
            migraphx::shape scalar_shape{v.get_shape().type()};
834
835
836
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

Paul's avatar
Paul committed
837
838
        return prog.add_literal(v);
    }
Paul's avatar
Paul committed
839

Paul's avatar
Paul committed
840
    instruction_ref
841
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
842
843
    {
        float alpha = 1.0f;
Khalique's avatar
Khalique committed
844
        float beta  = 1.0f;
Paul's avatar
Paul committed
845
846
        bool transa = false;
        bool transb = false;
847
        if(contains(info.attributes, "alpha"))
Paul's avatar
Paul committed
848
        {
849
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
Paul's avatar
Paul committed
850
        }
851
        if(contains(info.attributes, "beta"))
Paul's avatar
Paul committed
852
        {
853
            beta = parse_value(info.attributes.at("beta")).at<float>();
Paul's avatar
Paul committed
854
        }
855
        if(contains(info.attributes, "transA"))
Paul's avatar
Paul committed
856
        {
857
            transa = parse_value(info.attributes.at("transA")).at<bool>();
Paul's avatar
Paul committed
858
        }
859
        if(contains(info.attributes, "transB"))
Paul's avatar
Paul committed
860
        {
861
            transb = parse_value(info.attributes.at("transB")).at<bool>();
Paul's avatar
Paul committed
862
        }
863
864
865
866
867
868

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

869
870
        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
Paul's avatar
Paul committed
871
872
        if(args.size() == 3)
        {
873
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
874
            {
Shucai Xiao's avatar
Shucai Xiao committed
875
                auto out_lens   = l1->get_shape().lens();
876
                out_lens.back() = l2->get_shape().lens().back();
Shucai Xiao's avatar
Shucai Xiao committed
877
                auto l3         = args[2];
Shucai Xiao's avatar
Shucai Xiao committed
878
879
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
Khalique's avatar
Khalique committed
880
                {
881
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
Khalique's avatar
Khalique committed
882
                }
883
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
884
            }
Paul's avatar
Paul committed
885
        }
886
887

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
Paul's avatar
Paul committed
888
889
    }

890
    template <class Op>
891
    instruction_ref
892
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
893
    {
Shucai Xiao's avatar
Shucai Xiao committed
894
895
        auto l0      = args[0];
        auto l1      = args[1];
896
897
898
899
900
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
Shucai Xiao's avatar
Shucai Xiao committed
901
        if(l0_lens.size() == 1)
902
903
904
905
906
907
908
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
Shucai Xiao's avatar
Shucai Xiao committed
909
        if(l1_lens.size() == 1)
910
911
912
913
914
915
916
917
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
Shucai Xiao's avatar
Shucai Xiao committed
918
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
919
920
921
922
923
924
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
925
            l0_broadcasted_lens = output_lens;
926
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
927
            l1_broadcasted_lens = output_lens;
928
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
Shucai Xiao's avatar
Shucai Xiao committed
929
            if(l0_lens != l0_broadcasted_lens)
930
931
932
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
Shucai Xiao's avatar
Shucai Xiao committed
933
            if(l1_lens != l1_broadcasted_lens)
934
935
936
937
938
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

939
        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
940
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
Shucai Xiao's avatar
Shucai Xiao committed
941
        if(is_a_prepended)
942
943
944
945
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
Shucai Xiao's avatar
Shucai Xiao committed
946
        if(is_b_appended)
947
948
949
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }
Shucai Xiao's avatar
Shucai Xiao committed
950

951
952
953
        return dot_res;
    }

954
    instruction_ref
955
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
956
    {
Scott Thornton's avatar
Scott Thornton committed
957
958
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
959
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
960
        if(contains(info.attributes, "epsilon"))
961
        {
962
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
963
        }
964
        if(contains(info.attributes, "momentum"))
965
        {
966
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
967
        }
968
        if(contains(info.attributes, "spatial"))
969
        {
970
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
971
972
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
973
        }
Paul's avatar
Paul committed
974
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
Paul's avatar
Paul committed
975
        return prog.add_instruction(op, std::move(args));
976
977
    }

978
979
    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
kahmed10's avatar
kahmed10 committed
980
981
982
983
984
985
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
986
        if(contains(info.attributes, "epsilon"))
kahmed10's avatar
kahmed10 committed
987
        {
988
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
kahmed10's avatar
kahmed10 committed
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        ;
        auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5         = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

1013
1014
    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
1015
    {
Khalique's avatar
Khalique committed
1016
        float alpha = 0.01; // default alpha val for leaky relu
1017
        if(contains(info.attributes, "alpha"))
1018
        {
1019
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
1020
1021
1022
1023
1024
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

1025
    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
1026
1027
    {
        float alpha = 1.0; // default alpha val for elu
1028
        if(contains(info.attributes, "alpha"))
Khalique's avatar
Khalique committed
1029
        {
1030
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
Khalique's avatar
Khalique committed
1031
1032
1033
1034
1035
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

1036
    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
1037
1038
    {
        float alpha = 0.0001;
Khalique's avatar
Khalique committed
1039
1040
1041
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
1042
1043
1044
1045
1046
1047
1048
1049
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
Khalique's avatar
Khalique committed
1050
1051
1052
1053
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

1054
1055
    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
1056
1057
1058
    {
        float scale = 1.0;
        std::vector<float> bias{};
1059
        if(contains(info.attributes, "scale"))
Khalique's avatar
Khalique committed
1060
        {
1061
            scale = parse_value(info.attributes.at("scale")).at<float>();
Khalique's avatar
Khalique committed
1062
1063
        }

1064
        if(contains(info.attributes, "bias"))
Khalique's avatar
Khalique committed
1065
        {
1066
            auto&& bias_floats = info.attributes["bias"].floats();
Khalique's avatar
Khalique committed
1067
1068
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
Shucai Xiao's avatar
Shucai Xiao committed
1069
1070
1071
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();
Khalique's avatar
Khalique committed
1072

Shucai Xiao's avatar
Shucai Xiao committed
1073
1074
        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});
Khalique's avatar
Khalique committed
1075

1076
        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
Paul's avatar
Paul committed
1077
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
Shucai Xiao's avatar
Shucai Xiao committed
1078
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
Paul's avatar
Paul committed
1079
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
Khalique's avatar
Khalique committed
1080
    }
Khalique's avatar
Khalique committed
1081

Khalique's avatar
Khalique committed
1082
    instruction_ref
1083
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
1084
1085
    {
        std::vector<int64_t> perm{};
1086
        if(contains(info.attributes, "perm"))
Khalique's avatar
Khalique committed
1087
        {
1088
            auto&& perm_vals = info.attributes["perm"].ints();
Khalique's avatar
Khalique committed
1089
1090
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
Paul's avatar
Paul committed
1091
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
Khalique's avatar
Khalique committed
1092
1093
    }

1094
    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
Khalique's avatar
Khalique committed
1095
1096
1097
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
1098
        if(contains(info.attributes, "pads"))
Khalique's avatar
Khalique committed
1099
        {
1100
            auto&& pad_vals = info.attributes["pads"].ints();
Khalique's avatar
Khalique committed
1101
1102
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
1103
        // check if padding is actually being done (at least one value is nonzero)
Khalique's avatar
Khalique committed
1104
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
1105
1106
1107
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
1108
        if(contains(info.attributes, "value"))
Khalique's avatar
Khalique committed
1109
        {
1110
            value = parse_value(info.attributes.at("value")).at<float>();
Khalique's avatar
Khalique committed
1111
        }
1112
        if(contains(info.attributes, "mode"))
Khalique's avatar
Khalique committed
1113
        {
1114
            auto mode = info.attributes.at("mode").s();
Khalique's avatar
Khalique committed
1115
1116
1117
1118
1119
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }

    // Use a literal instruction to replace the Shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
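
    // Example: for an input of shape {2, 3, 4} this emits an int64 literal
    // containing {2, 3, 4}, which downstream parsers such as ConstantOfShape
    // and Expand can then read back at parse time via eval().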

    // Use a literal instruction to replace the ConstantFill operator. In RNN, the
    // input shape and value are fixed, so there is no need to do the actual
    // computation for the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
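
    // Example: with input_as_shape = 1, an input literal holding {2, 3} and
    // value = 1.0f yields a 2x3 literal filled with ones; with
    // input_as_shape = 0 the dims come from the "shape" attribute instead.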

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // the output type is the type of the value attribute
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
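
    // Example: an input literal holding {3, 2} with value = 0.5f produces a
    // 3x2 literal filled with 0.5; an empty input produces a one-element
    // (scalar-like) literal holding the value.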

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
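
    // Example: in_lens = {3, 1} expanded with dims = {2, 1, 6} broadcasts to
    // out_lens = {2, 3, 6}, following the usual numpy-style broadcasting rules.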

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions:
        // one for the forward direction and one for the reverse direction.
        // If only one activation function is provided, use it for both
        // the forward and reverse directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
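
    // The two results map to the ONNX RNN outputs: Y, the hidden states for
    // every time step with shape {seq_length, num_directions, batch_size,
    // hidden_size}, and Y_h, the hidden state of the last time step.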

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator. We use
            // the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 activation functions
            // are provided, assume forward and reverse use the same pair
            // of activation functions. If 3 activation functions are
            // provided, assume the 3rd one is repeated once and used by
            // the reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
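
    // Per the ONNX spec, linear_before_reset = 1 computes the hidden gate as
    // tanh(Wh*x + r .* (Rh*h + Rbh)) rather than tanh(Wh*x + Rh*(r .* h))
    // (bias terms partly omitted here), matching the cuDNN GRU formulation.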

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator. We use
            // the following algorithm: if 1 activation function is
            // provided, repeat it six times. If 2 are provided, repeat
            // the 2nd once, then repeat all three once. If 3 are
            // provided, repeat all three once. The same algorithm is
            // used when 4, 5, or 6 activation functions are provided.
            // This may need to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
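
    // The default {"sigmoid", "tanh", "tanh"} matches the ONNX LSTM activation
    // triple (f, g, h): f gates the input/forget/output gates, g transforms the
    // cell input, and h transforms the cell state on the way to the output.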

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
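
    // Example: ReduceSum on a {2, 3, 4} input with axes = {1} produces
    // {2, 1, 4} when keepdims = 1, and the trailing squeeze collapses it
    // to {2, 4} when keepdims = 0.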

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }
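
    // These rewrites use the identities ReduceL1(x) = ReduceSum(|x|),
    // ReduceL2(x) = sqrt(ReduceSum(x * x)), ReduceLogSum(x) = log(ReduceSum(x)),
    // ReduceLogSumExp(x) = log(ReduceSum(exp(x))) and
    // ReduceSumSquare(x) = ReduceSum(x * x).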

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
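
    // The "to" attribute holds an onnx::TensorProto data type code, which
    // get_type() maps onto a migraphx type; e.g. to = 1 requests float,
    // to = 7 int64, and to = 10 half precision.
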
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
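
    // Example: splitting a {6, 4} input into 3 outputs along axis 0 with no
    // "split" attribute yields three slices covering rows [0,2), [2,4) and
    // [4,6); an explicit split = {1, 2, 3} must sum to the axis length 6.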

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }
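
    // Lowering order: initializers become literals, any graph input without an
    // initializer becomes a parameter, and walking the graph outputs through
    // parse_node pulls in every reachable node before the final return is added.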

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                else if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                std::size_t output_num = static_cast<std::size_t>(node.output().size());
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }
            // Even nodes with no declared outputs produce output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                auto output_num = std::min<std::size_t>(node.output().size(), result.size());
                std::transform(node.output().begin(),
                               node.output().begin() + output_num,
                               result.begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(x, y); });
            }
        }
    }
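
    // parse_node is effectively a memoized depth-first traversal: each input
    // that is itself produced by a node is parsed recursively first, and the
    // `instructions` map both caches results and avoids repeated work, so every
    // node is lowered exactly once regardless of how many consumers it has.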

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }
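
    // Any dimension that is missing or has a non-positive dim_value is treated
    // as dynamic and replaced with the user-supplied batch_size, which is how
    // models exported with a symbolic batch dimension get a concrete shape.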

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}
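
// Example usage (illustrative; assumes "model.onnx" exists on disk):
//
//   migraphx::onnx_options opts;
//   opts.batch_size = 4; // substituted for any dynamic/symbolic dims
//   migraphx::program p = migraphx::parse_onnx("model.onnx", opts);
//
// parse_onnx_buffer performs the same parse for models already held in memory.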

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx