#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

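// the protobuf-generated ONNX types are compiled into their own namespace
// (presumably renamed to avoid clashing with the upstream onnx namespace);
// alias it so the code below can keep using the onnx:: prefix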
namespace onnx = onnx_for_migraphx;

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog                  = program();
    bool is_pytorch               = false;
    std::size_t default_dim_value = 1;
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
    bool skip_unknown_operators = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // ONNX operators are registered in alphabetical order by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Names may be all lower case or have the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
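            // opset 6-style explicit "broadcast"/"axis" attributes; newer
            // opsets rely on implicit numpy-style broadcasting instead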
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

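    // ops such as reshape and gather require a standard (packed, row-major)
    // input layout; insert a contiguous instruction only when needed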
    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

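    // n-ary ops (Sum/Max/Min) are lowered to a left fold of broadcastable
    // binary ops over the argument list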
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    // if an optional third input (the bias) is present, broadcast it along
    // `axis` over the current output shape and add it to curr_ins
    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

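    // the convolution/pooling ops only model symmetric padding, so asymmetric
    // padding is peeled off into an explicit pad instruction beforehand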
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

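    // Clip: opset 11+ passes min/max as optional inputs, earlier opsets as
    // "min"/"max" attributes; with only a lower bound it reduces to max(x, min)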
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

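    // auto_pad = SAME_UPPER/SAME_LOWER: pad so each output spatial dim equals
    // ceil(input / stride), with any odd remainder placed at the end (UPPER)
    // or the beginning (LOWER) of the axis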
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);
            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
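        // for asymmetric padding the deconvolution was run without padding, so
        // crop the requested amounts off each side of the output instead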
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, padding[0], padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
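        // Global pooling is ordinary pooling with the kernel spanning the
        // entire spatial extent of the input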
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
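            // max pooling must pad with the lowest float so padded elements
            // never win the max; average pooling pads with zeros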
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // TODO: change to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs; check the 5th (steps) first
        // to decide whether MIGraphX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

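        // when no axes were given, default to slicing every dimension in order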
        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
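        // the optional C input participates only when beta is nonzero and
        // non-empty; multibroadcast it to the output shape of A*B if needed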
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // args[1] is a vector, append 1 to the shape
        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
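        // broadcast the leading (batch) dimensions of both operands to a common
        // shape; the trailing two matrix dimensions are left untouched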
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5         = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
Khalique's avatar
Khalique committed
1150
        {
1151
            value = parse_value(info.attributes.at("value")).at<float>();
Khalique's avatar
Khalique committed
1152
        }
1153

1154
        if(contains(info.attributes, "mode"))
Khalique's avatar
Khalique committed
1155
        {
1156
            auto mode = info.attributes.at("mode").s();
Khalique's avatar
Khalique committed
1157
            if(mode != "constant")
1158
1159
1160
            {
                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
            }
Khalique's avatar
Khalique committed
1161
1162
1163
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
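
    // Note on the pads layout (ONNX convention): for an n-dimensional input,
    // pads holds 2*n values ordered [x1_begin, x2_begin, ..., x1_end, x2_end].
    // E.g. (values assumed for illustration) a {2, 3} input with
    // pads = {0, 1, 0, 1} adds one column before and one after axis 1,
    // giving a {2, 5} output whose new positions are filled with `value`.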

    // Use a literal instruction to replace the shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
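
    // E.g. (illustrative): for an input instruction whose shape is
    // {3, 224, 224}, parse_shape adds the int64 literal [3, 224, 224] of
    // shape {3}, which downstream nodes can then fold at parse time.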

    // Use a literal instruction to replace the ConstantFill operator. In RNN
    // models, the input shape and value are fixed, so there is no need to do
    // the actual computation for the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
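
    // E.g. (illustrative): with input_as_shape = 1 and a constant input
    // holding [1, 2, 50], the node folds to a literal of shape {1, 2, 50}
    // filled with `value`; with input_as_shape = 0 the same dims come from
    // the "shape" attribute instead.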

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // an empty input tensor means the output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
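
    // E.g. (illustrative): a constant shape input holding [2, 3] with
    // value = 1.0f folds to a {2, 3} literal of ones; an input tensor with
    // zero elements instead yields a scalar-like {1} literal. The output
    // type follows the type of the "value" attribute (float by default).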

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
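
    // E.g. (illustrative): an input of shape {3, 1} expanded against a
    // constant shape [2, 1, 6] broadcasts to {2, 3, 6}, following the
    // numpy-style rules implemented by compute_broadcasted_lens.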

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // the bidirectional case should have two activation functions:
        // one for forward and the other for reverse. If only one
        // activation function is provided, use it for both the forward
        // and reverse directions
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // clip handling to be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to make up 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output: the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
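
    // For reference (ONNX RNN operator): the inputs follow the order
    // X, W, R, B, sequence_lens, initial_h; the trailing optional inputs a
    // model omits are the ones padded with op::undefined above. E.g.
    // (illustrative) a node supplying only X, W, R is padded with three
    // undefined arguments before op::rnn is added.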

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided by the ONNX operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 activation functions
            // are provided, assume forward and reverse use the same pair.
            // For the case of 3 activation functions, assume the 3rd one
            // is repeated once and used by the reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the pair of activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make up 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output: the last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional direction
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided by the ONNX operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat the 1st six times. If 2 activation
            // functions are provided, repeat the 2nd once, then repeat
            // all three once. If 3 activation functions are provided,
            // repeat all three once. The same algorithm is used when 4,
            // 5, or 6 activation functions are provided. This may need
            // to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make up 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output: the last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output: the last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
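
    // E.g. (illustrative): a bidirectional LSTM node with
    // activations = {"Sigmoid", "Tanh"} is expanded by the case-2 rule above
    // to {sigmoid, tanh, tanh, sigmoid, tanh, tanh}, i.e. the same triple
    // for the forward and reverse directions.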

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
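
    // E.g. (illustrative): ReduceSum on a {2, 3, 4} input with axes = {1}
    // yields {2, 1, 4} when keepdims = 1; with keepdims = 0 the reduced
    // axis is squeezed away, yielding {2, 4}.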

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }

    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
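
    // E.g. (illustrative): Split on a {4, 6} input along axis 1 with
    // split = {2, 4} produces two slices of shapes {4, 2} and {4, 4};
    // with no split attribute and 3 outputs it produces three {4, 2}
    // slices, and uneven divisions throw.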

    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }

        // construct a depth x depth identity matrix to gather rows from
        std::vector<float> depth_input(depth * depth, 0.0f);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = 1.0f;
        }

        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = prog.add_literal({s, depth_input});
        auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});

        // Finally, we need a transpose to move the innermost dim to the axis dim
        int n_rank = gather_out->get_shape().lens().size();
        if(axis < -n_rank or axis >= n_rank)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);
        auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
        auto lens   = tr_out->get_shape().lens();

        auto off_val       = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
        auto on_val        = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
        auto diff          = prog.add_instruction(op::sub{}, on_val, off_val);
        auto unsq_off_val  = prog.add_instruction(op::multibroadcast{lens}, off_val);
        auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
        auto l_mul         = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
        return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
    }
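
    // E.g. (illustrative): indices {1, 0} with depth = 3, axis = -1, and
    // values = {off, on} gather rows 1 and 0 of the 3x3 identity, giving
    // {{0, 1, 0}, {1, 0, 0}}, which is then rescaled elementwise as
    // (on - off) * x + off.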

    instruction_ref
    parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        for(std::size_t i = 0; i < repeats.size(); i++)
        {
            auto l1 = l0;
            for(int64_t j = 1; j < repeats[i]; j++)
            {
                l0 = prog.add_instruction(op::concat{i}, l0, l1);
            }
        }
        return l0;
    }
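
    // E.g. (illustrative): Tile on a {2, 3} input with repeats = {2, 2}
    // concatenates the input with itself once along each axis in turn,
    // producing a {4, 6} output.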

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                std::vector<std::size_t> dims;
                if(map_input_dims.count(name) > 0)
                {
                    dims = map_input_dims.at(name);
                }

                shape s            = parse_type(input.type(), dims);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                if(skip_unknown_operators)
                    result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
                else
                    MIGRAPHX_THROW("Unknown operator: " + node.op_type());
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find the instructions corresponding to the graph outputs
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            // FLOAT16 values are stored bit-wise in int32_data;
            // reinterpret each raw uint16 as a half
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in the onnx file, use dims = 1 to fill the initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }

        if(!input_dims.empty())
        {
            return {shape_type, input_dims};
        }

        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                               {
                                   return default_dim_value;
                               }
                               return d.dim_value();
                           }
                           else
                           {
                               return default_dim_value;
                           }
                       });

        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }
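
    // E.g. (illustrative): a graph input typed float[?, 3, 224, 224]
    // (dynamic batch dim) becomes {1, 3, 224, 224} under the default
    // default_dim_value of 1, unless map_input_dims supplies explicit
    // dims for that input name.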

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
    onnx_parser parser;
    parser.map_input_dims         = options.map_input_dims;
    parser.default_dim_value      = options.default_dim_value;
    parser.skip_unknown_operators = options.skip_unknown_operators;

    if(options.print_program_on_error)
    {
        // Log the program when it can't be parsed
        try
        {
            parser.parse_from(std::forward<Ts>(xs)...);
        }
        catch(...)
        {
            std::cerr << parser.prog << std::endl;
            throw;
        }
    }
    else
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, const onnx_options& options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
    return parse_onnx_from(options, data, size);
}
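
// Illustrative usage sketch (the file name and input name are assumptions,
// not part of this file; the option fields match those set above):
//
//     migraphx::onnx_options opts;
//     opts.default_dim_value          = 1;  // substitute for dynamic dims
//     opts.map_input_dims["input"]    = {1, 3, 224, 224};
//     opts.skip_unknown_operators     = false;
//     migraphx::program p = migraphx::parse_onnx("model.onnx", opts);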

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx