#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;

    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog                  = program();
    bool is_pytorch               = false;
    std::size_t default_dim_value = 1;
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
    bool skip_unknown_operators = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // sort ONNX operators alphabetically by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names in all lowercase or with the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
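        // bind a parse_* member function as this operator's handler; the lambda
        // forwards the operator name along with the node's attributes and inputs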
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
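        // numpy-style implicit broadcasting: multibroadcast both arguments to a
        // common output shape before applying the binary operator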
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
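        // ONNX 2-D pads are ordered [top, left, bottom, right]; asymmetric padding
        // cannot be expressed through the op's symmetric padding attribute, so an
        // explicit pad instruction is inserted instead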
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
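        // Clip changed across opsets: newer opsets pass min/max as optional
        // inputs, while older ones pass them as attributes; both forms are handled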
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
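        // the axis attribute defaults to 1 when absent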
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
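        // keepdims defaults to 1; when it is 0 the reduced axis is squeezed away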
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
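        // SAME_UPPER/SAME_LOWER auto_pad computes pads so the output spatial dims
        // are ceil(input / stride); the per-dimension math is in calculate_padding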
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
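        // gather pads/strides/dilations/group, resolve any auto_pad request, then
        // emit the convolution followed by an optional broadcast bias add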
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
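        // with asymmetric padding the deconvolution is computed unpadded and the
        // requested padding is instead trimmed off the output with a slice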
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
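        // the pooling mode comes from the operator name; Global* variants pool
        // over the entire spatial dimensions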
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // changed to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

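        // if no axes were given, default to all dimensions in order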
        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
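        // Gemm computes alpha * op(A) * op(B) + beta * C, where op() optionally
        // transposes the last two axes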
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
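        // numpy-style matmul: 1-D inputs are temporarily promoted to matrices,
        // batch dimensions are multibroadcast, and the inserted axes are squeezed
        // off the result afterwards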
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
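        // spatial mode (the ONNX default) shares statistics per channel across
        // H and W; per_activation keeps separate statistics for each activation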
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5         = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
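        // newer opsets pass pads (and the fill value) as inputs; older opsets
        // pass them as attributes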
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
            {
                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
            }
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace Shape, since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
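    // Illustrative example of parse_shape above: an input with lens
    // {2, 3, 4} becomes the int64 literal [2, 3, 4] of shape {3}.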

    // Use a literal instruction to replace the ConstantFill operator. In RNN,
    // the input shape and value are fixed, so there is no need to do the
    // actual computation for the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
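    // Illustrative example of parse_constant_fill above: with
    // input_as_shape = 1, an input literal holding [2, 3] and value = 1.5
    // produce a {2, 3} tensor filled with 1.5.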

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only one element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // the output type follows the type of the value attribute
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
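    // Illustrative example of parse_constant_of_shape above: an input
    // literal holding [2, 3] with value = 5 yields a {2, 3} tensor of fives;
    // an empty input yields a single-element scalar-like tensor.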

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
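    // Illustrative example of parse_expand above: an input with lens {3, 1}
    // and a shape input of [2, 3, 4] broadcasts to lens {2, 3, 4}.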

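    // The recurrent-operator parsers below (RNN, GRU, LSTM) all follow the
    // same pattern: normalize the activation-function list, pad the argument
    // list with undefined instructions, then emit the recurrent op followed
    // by separate last-output instructions so each ONNX output name can be
    // bound.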
    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions:
        // one for the forward direction and one for the reverse direction.
        // If only one activation function is provided, use it for both
        // directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // clip handling to be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append undefined
        // operators until there are 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output is the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output is the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided by the ONNX operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 activation functions
            // are provided, assume forward and reverse use the same pair.
            // If 3 activation functions are provided, assume the 3rd one
            // is repeated once and used by the reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the pair of activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }
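        // Illustrative examples of the bidirectional expansion above:
        //   {f}       -> {f, f, f, f}
        //   {f, g}    -> {f, g, f, g}
        //   {f, g, h} -> {f, g, h, h}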

1489
1490
1491
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1492
        if(name_it != vec_names.end())
1493
1494
1495
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }
1496

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output is the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output is the last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided by the ONNX operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat the 1st one six times. If 2 activation
            // functions are provided, repeat the 2nd once, then repeat
            // all three once. If 3 activation functions are provided,
            // repeat all three once. The same algorithm is used when 4,
            // 5, or 6 activation functions are provided. This may need
            // to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }
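        // Illustrative examples of the expansion above:
        //   bidirectional:   {f, g} -> {f, g, g, f, g, g}
        //   forward/reverse: {f, g} -> {f, g, g}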

1662
1663
1664
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1665
        if(name_it != vec_names.end())
1666
1667
1668
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }
Shucai Xiao's avatar
Shucai Xiao committed
1669
1670

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output is the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output is the last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output is the last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
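    // Illustrative example: reducing a {2, 3, 4} input over axes = {1} with
    // keepdims = 0 first yields {2, 1, 4}, and the squeeze then gives {2, 4}.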

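    // The Reduce* variants below are lowered as compositions with reduce_sum:
    //   ReduceL1        = reduce_sum(abs(x))
    //   ReduceL2        = sqrt(reduce_sum(x * x))
    //   ReduceLogSum    = log(reduce_sum(x))
    //   ReduceLogSumExp = log(reduce_sum(exp(x)))
    //   ReduceSumSquare = reduce_sum(x * x)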
    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
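    // Illustrative example of parse_cast above: to = 1 (ONNX FLOAT) converts
    // the input to float_type via the get_type mapping defined further below.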
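    // Split: when no split attribute is given, the input is divided equally
    // among the node's outputs; e.g. a {4, 6} input split on axis 1 into 3
    // outputs yields three {4, 2} slices.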
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }

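    // OneHot: gathers rows of a depth x depth identity matrix by index,
    // moves the one-hot dimension to the requested axis, then rescales the
    // 0/1 entries with the off/on values taken from the third input.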
    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }

        std::vector<float> depth_input(depth * depth, 0.0f);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = 1.0f;
        }

        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = prog.add_literal({s, depth_input});
        auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});

        // Finally, we need a transpose to move the innermost dim to the axis dim
        int n_rank = gather_out->get_shape().lens().size();
        if(axis < -n_rank or axis >= n_rank)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);
        auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
        auto lens   = tr_out->get_shape().lens();

        auto off_val       = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
        auto on_val        = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
        auto diff          = prog.add_instruction(op::sub{}, on_val, off_val);
        auto unsq_off_val  = prog.add_instruction(op::multibroadcast{lens}, off_val);
        auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
        auto l_mul         = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
        return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
    }

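    // Tile: implemented as repeated concatenation along each axis; e.g.
    // repeats = [2, 3] on a {p, q} input produces a {2p, 3q} output.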
    instruction_ref
    parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        for(std::size_t i = 0; i < repeats.size(); i++)
        {
            auto l1 = l0;
            for(int j = 1; j < repeats[i]; j++)
            {
                l0 = prog.add_instruction(op::concat{i}, l0, l1);
            }
        }
        return l0;
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

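    // parse_graph walks the GraphProto in order: initializers become
    // literals, non-initializer inputs become parameters, and each node's
    // outputs are bound to the instructions returned by its parser.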
    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                std::vector<std::size_t> dims;
                if(map_input_dims.count(name) > 0)
                {
                    dims = map_input_dims.at(name);
                }

                shape s            = parse_type(input.type(), dims);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                if(skip_unknown_operators)
                    result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
                else
                    MIGRAPHX_THROW("Unknown operator: " + node.op_type());
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

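    // parse_value maps an AttributeProto to a literal; e.g. an INTS
    // attribute holding {1, 2, 3} becomes an int64 literal of shape {3} via
    // from_repeated above.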
    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

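    // parse_type resolves a graph input's element type and dimensions.
    // Explicit dims passed via map_input_dims take precedence; otherwise any
    // unknown or non-positive dimension falls back to default_dim_value.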
    shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }

        if(!input_dims.empty())
        {
            return {shape_type, input_dims};
        }

        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                               {
                                   return default_dim_value;
                               }
                               return d.dim_value();
                           }
                           else
                           {
                               return default_dim_value;
                           }
                       });

        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

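    // get_type maps onnx TensorProto::DataType codes to migraphx types,
    // e.g. 1 = FLOAT, 7 = INT64, 10 = FLOAT16.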
    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
    onnx_parser parser;
    parser.map_input_dims         = options.map_input_dims;
    parser.default_dim_value      = options.default_dim_value;
    parser.skip_unknown_operators = options.skip_unknown_operators;

    if(options.print_program_on_error)
    {
        // Log the program when it can't be parsed
        try
        {
            parser.parse_from(std::forward<Ts>(xs)...);
        }
        catch(...)
        {
            std::cerr << parser.prog << std::endl;
            throw;
        }
    }
    else
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, const onnx_options& options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
    return parse_onnx_from(options, data, size);
}
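// Minimal usage sketch (illustrative only; "data" is a hypothetical input
// name and model.onnx a placeholder path):
//
//     migraphx::onnx_options options;
//     options.map_input_dims["data"] = {1, 3, 224, 224};
//     options.default_dim_value      = 1;
//     migraphx::program p            = migraphx::parse_onnx("model.onnx", options);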

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx