#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
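    // An op_func converts a single ONNX node (its attributes plus the
    // already-parsed input instructions) into one or more migraphx instructions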
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // ONNX operators are registered alphabetically by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Dropout", op::identity{}); // Dropout is a no-op at inference
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

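        // Operators whose parsing needs parser state are routed through
        // member-function callbacks registered with add_mem_op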
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);

        // init the activation function map
        init_actv_func();
    }

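    // Registers the activation operations that the recurrent operators
    // (RNN, GRU, LSTM) may reference by name in their "activations" attribute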
    void init_actv_func()
    {
        // Names may be all lower case or have the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

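    // Compute the result dimensions of numpy-style multidirectional
    // broadcasting: shapes are right-aligned, missing leading dimensions act
    // as 1, and paired dimensions must be equal unless one of them is 1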
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

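    // Multibroadcast both operands to a common shape (when their lengths
    // differ) before applying the binary operation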
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

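    // Handles the optional bias input of Conv/Gemm-like nodes: broadcast it
    // along the given axis and add it to the computed result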
    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        // The operator's padding attribute is symmetric, so asymmetric padding
        // is applied with an explicit pad instruction on the input instead
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // opsets prior to 11 pass min/max as attributes instead of inputs
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

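    // Softmax and LogSoftmax share one parser; the axis attribute
    // (default 1) is forwarded to the migraphx operator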
    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        // keepdims == 0: drop the reduced axis with a squeeze to match ONNX semantics
        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

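    // auto_pad=SAME_* pads so the output spatial size equals
    // ceil(input / stride); SAME_UPPER places any extra padding at the end
    // of an axis, SAME_LOWER at the beginning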
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

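    // ConvTranspose: asymmetric pads are trimmed off with a slice after the
    // deconvolution, while output_padding/output_shape grow the result with a
    // pad instruction to reach the requested size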
    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            // asymmetric padding cannot be expressed by the deconvolution op,
            // so slice the excess off the output instead
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            // pad max pooling with the lowest float so padded values never win
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

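    // Reshape carries the target dims as an attribute in older models or as a
    // second (constant) input in newer ones; truly dynamic shapes are rejected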
    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // TODO: handle negative values for the axis attribute
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

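    // Slice accepts starts/ends/axes either as attributes or, in newer opsets,
    // as additional inputs (data, starts, ends, axes, steps); the inputs must
    // be compile-time constants and only unit steps are supported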
    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

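    // MatMul follows numpy matmul rules: 1-D operands are promoted by
    // prepending (lhs) or appending (rhs) a singleton dimension, batch
    // dimensions are broadcast, and the promoted dimensions are squeezed
    // away from the result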
    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // args[1] is a vector, append 1 to the shape
        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

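    // BatchNormalization is parsed in inference mode; the spatial attribute
    // selects per-channel (spatial) or per-activation statistics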
    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

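    // ImageScaler computes y = scale * x + bias with one bias value per
    // channel (broadcast along axis 1)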
    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the shape, since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
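
    // ConstantFill produces a tensor of a given shape filled with one value.
    // Illustrative example: shape = {2, 2} with value = 1.5f yields the literal
    // [1.5, 1.5, 1.5, 1.5]; with input_as_shape == 1 the shape is instead read
    // from the (compile-time evaluated) first argument.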

    // Use a literal instruction to replace the ConstantFill operator. In RNN, the
    // input shape and value are fixed, so there is no need to do the actual
    // computation for the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
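
    // ConstantOfShape serves a similar purpose with standardized semantics:
    // the output shape always comes from the first input, and the fill value
    // from the optional one-element "value" tensor attribute (default 0.0f).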

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
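
    // Expand broadcasts the input to the shape held by the second argument,
    // which must be evaluable at parse time. Illustrative example: a {3, 1}
    // input expanded with dims = {3, 4} gives a {3, 4} output; the output
    // lengths are computed from both operand shapes, so either side may
    // contain broadcastable 1s.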

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
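
    // ONNX RNN nodes produce up to two outputs: the hidden state for every
    // time step and the hidden state of the last step. The parser always
    // builds both; parse_graph later keeps only as many results as the node
    // declares outputs.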

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // A bidirectional case should have two activation functions: one for
        // the forward direction and one for the reverse. If only one
        // activation function is provided, use it for both directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // Clip handling to be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
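
    // GRU parsing follows the same pattern as RNN above but needs two
    // activation functions per direction (one for the update/reset gates, one
    // for the hidden computation). Illustrative example: activations =
    // {"sigmoid"} on a bidirectional node is expanded to
    // {"sigmoid", "sigmoid", "sigmoid", "sigmoid"}.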

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // The bidirectional case uses 4 activation functions, but the ONNX
            // operator spec does not say how fewer should be expanded. We use
            // the following convention: if 1 function is provided, repeat it
            // four times; if 2 are provided, assume forward and reverse use
            // the same pair; if 3 are provided, assume the 3rd one is repeated
            // once and used by the reverse direction. This may change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the pair of activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
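
    // LSTM needs three activation functions per direction (for the gates, the
    // cell update, and the cell-to-hidden output) and adds a third result, the
    // last cell state. Illustrative example: activations = {"sigmoid", "tanh"}
    // on a forward-only node becomes {"sigmoid", "tanh", "tanh"}.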

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // The bidirectional case uses 6 activation functions, but the ONNX
            // operator spec does not say how fewer should be expanded. We use
            // the following convention: if 1 function is provided, repeat it
            // six times; if 2 are provided, repeat the 2nd once to get three,
            // then repeat all three; if 3 are provided, repeat all three once.
            // The analogous padding is applied when 4 or 5 functions are
            // provided. This may change later.
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
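
    // Shared skeleton for all Reduce* operators: collect the axes (defaulting
    // to all of them), then squeeze the reduced axes away when keepdims == 0.
    // Illustrative example: reduce_sum over axes = {1} of a {2, 3, 4} input
    // gives {2, 1, 4} with keepdims and {2, 4} without.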

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
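
    // The remaining reductions are composed from simpler instructions around
    // the skeleton above, e.g. ReduceL1(x) = reduce_sum(abs(x)) and
    // ReduceL2(x) = sqrt(reduce_sum(x * x)).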

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }
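
    // Cast converts element types; the "to" attribute carries an ONNX
    // TensorProto::DataType code that get_type() below maps to a migraphx
    // shape type (e.g. 1 -> float_type, 7 -> int64_type).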

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
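
    // Split is lowered to one slice per output along the chosen axis.
    // Illustrative example: a {2, 6} input with axis = 1 and split = {2, 4}
    // yields slices [0, 2) and [2, 6) of axis 1; without a "split" attribute
    // the axis is divided evenly among the declared outputs.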
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }

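    // The two parse_from overloads deserialize the protobuf model, one from a
    // stream and one from an in-memory buffer, and hand the graph to parse_graph.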
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }
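
    // parse_graph assumes nodes are topologically ordered, as the ONNX spec
    // requires: initializers become literals, graph inputs without initializer
    // data become parameters, and each node's outputs are registered by name
    // so later nodes can look them up.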

    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find the instructions corresponding to the graph outputs
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }
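
    // Helpers for reading node attributes: get_attributes indexes them by
    // name, and parse_value converts a scalar, tensor, or repeated-value
    // attribute into a migraphx literal (string- and graph-valued attributes
    // come back empty).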

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }
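
    // Initializer tensors arrive either as raw bytes ("raw_data") or as typed
    // repeated fields; FLOAT16 data is stored in int32_data and is repacked
    // below by reinterpreting each 16-bit pattern as a half.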

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }
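
    // parse_type maps an ONNX tensor type to a migraphx shape. Any dynamic or
    // non-positive dimension is replaced by the user-supplied batch_size, so,
    // for example, a model input declared as [N, 3, 224, 224] parses to
    // {batch_size, 3, 224, 224}.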

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }
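
    // Maps ONNX TensorProto::DataType codes (1 = FLOAT, 7 = INT64,
    // 10 = FLOAT16, ...) to migraphx shape types; unsupported codes throw.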

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};
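
// parse_onnx_from is the shared driver behind the public entry points below.
// A minimal usage sketch (the file name is hypothetical):
//   onnx_options options;
//   options.batch_size = 4;
//   program p = parse_onnx("model.onnx", options);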

template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx