#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <algorithm>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

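// Parses an ONNX graph into a migraphx::program. Each ONNX operator name is
// mapped to a parser callback in `ops`; a callback appends the equivalent
// migraphx instruction(s) to the program and returns their instruction_refs.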
struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // Register the supported ONNX operators (roughly alphabetical by name)
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Supports names in all lower case or with the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

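    // Register a parser that produces a single instruction; the result is
    // wrapped in a one-element vector to match the op_func signature.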
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi-output op: f returns all of the node's outputs
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

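    // Register a member-function parser, binding *this and forwarding the
    // operator name as the first argument.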
    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

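    // Binary ops honor the legacy (pre-opset-7) "broadcast"/"axis" attributes
    // when present; otherwise the operands are broadcast to a common shape.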
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

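    // Compute the numpy-style broadcast of two sets of dimensions: the shorter
    // shape is right-aligned against the longer one, and each aligned pair of
    // dimensions must be equal or contain a 1.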
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

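    // Some ops require a standard (packed, row-major) layout; insert a
    // contiguous instruction only when the input shape is not already standard.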
    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

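    // Multibroadcast both arguments to their common shape (when needed) before
    // applying the binary op.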
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

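    // Variadic ops (Sum/Max/Min) are lowered as a left fold of the binary op
    // over all inputs, broadcasting at each step.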
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

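    // Conv/Gemm-style nodes pass an optional bias as a third input; broadcast
    // it along the given axis (the channel dimension) and add it.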
    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

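    // The convolution and pooling ops only support symmetric padding, so
    // asymmetric ONNX pads become an explicit pad instruction on the input.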
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

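    // Clip: newer opsets (11+) pass min/max as optional inputs, older opsets
    // as attributes. With only a lower bound this reduces to max(); with no
    // bounds it degenerates to identity.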
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

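    // Apply the auto_pad attribute: for SAME_UPPER/SAME_LOWER, derive the
    // per-side padding from input size, stride, dilation and kernel size, then
    // let check_asym_padding choose between op padding and an explicit pad.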
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

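    // Conv/ConvInteger: maps the pads/strides/dilations/auto_pad/group
    // attributes onto the convolution op; pads and a non-NOTSET auto_pad are
    // mutually exclusive. An optional third input is the bias.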
    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

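    // ConvTranspose: asymmetric pads are applied by slicing the output rather
    // than padding the input; output_padding and output_shape are realized
    // with a trailing pad instruction.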
    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

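    // Pooling: the mode comes from the operator name, Global* variants pool
    // over the full spatial dimensions, and max pooling pads with the lowest
    // float so that padding never wins the comparison.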
    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

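    // Reshape: the target dims come from the "shape" attribute (one input) or
    // from a constant second input; dynamic shapes are rejected.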
    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // TODO: handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // Slice can have up to 5 inputs; check the 5th (steps) first
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

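    // Gemm: computes alpha * op(A) * op(B) + beta * C, where transA/transB are
    // implemented by transposing the last two axes and C is multibroadcast to
    // the output shape when necessary.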
    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

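    // MatMul follows numpy semantics: 1-D inputs are temporarily promoted to
    // matrices (the extra axis is squeezed out afterwards) and batch
    // dimensions are broadcast before the dot product.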
    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // if args[0] is a vector, prepend 1 to its shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the Shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the ConstantFill operator. In RNN, the
    // input shape and value are fixed, so there is no need to do the actual
    // computation for the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

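    // ConstantOfShape: the output dims come from the (constant) input
    // argument, and every element is filled with the single value from the
    // value attribute (0.0f when the attribute is absent).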
    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // the output type is taken from the value attribute
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }

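    // Expand: broadcast the input to the dims given by the second
    // (constant) argument via multibroadcast.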
    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }

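    // RNN: produces two outputs, the concatenation of all hidden states and
    // the last hidden state.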
    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions:
        // one for the forward direction and one for the reverse direction.
        // If only one activation function is provided, use it for both
        // directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // Clip handling to be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to have 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

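    // GRU: same overall structure as parse_rnn, but uses a pair of
    // activation functions per direction (default {"sigmoid", "tanh"}) and
    // honors the linear_before_reset attribute.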
    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in the onnx operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 activation functions
            // are provided, assume forward and reverse use the same pair.
            // If 3 activation functions are provided, assume the 3rd one
            // is repeated once and used by the reverse direction.
            // This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

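    // LSTM: uses three activation functions per direction (default
    // {"sigmoid", "tanh", "tanh"}) and adds a third output, the last cell
    // state.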
    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in the onnx operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat the 1st one six times. If 2 activation
            // functions are provided, repeat the 2nd one once, then
            // repeat all three once. If 3 activation functions are
            // provided, repeat all three once.
            // The same algorithm is used when 4, 5, or 6 activation
            // functions are provided. This may need to change later.
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

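    // Common helper for the Reduce* operators: reduce over all dimensions
    // unless an axes attribute is given, and squeeze the reduced axes away
    // when keepdims is 0.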
    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }

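    // Split: slice the input along the given axis, using either the sizes
    // from the split attribute or num_outputs equal pieces.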
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }

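    // OneHot: build a depth x depth matrix filled with the off value and
    // the on value on the diagonal, then gather rows from it with the
    // indices input. E.g. depth = 3 with on/off = 1/0 yields the 3x3
    // identity matrix.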
    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        std::vector<float> on_off_vals;

        migraphx::argument values_arg = args[2]->eval();
        check_arg_empty(values_arg, "ONEHOT: values - dynamic shape not supported");
        values_arg.visit([&](auto v) { copy(v, std::back_inserter(on_off_vals)); });
        float off_value = on_off_vals[0];
        float on_value  = on_off_vals[1];

        std::vector<float> depth_input(depth * depth, off_value);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = prog.add_literal({s, depth_input});
            return prog.add_instruction(op::gather{0}, {l0, args[0]});
        }
        MIGRAPHX_THROW("ONEHOT: MIGraphX does not support axis != -1");
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

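    // Build the program from the graph: initializers become literals,
    // remaining graph inputs become parameters, nodes are translated in
    // order (they are assumed to be topologically sorted), and the graph
    // outputs are marked with a return instruction.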
    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

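    // Convert an onnx TensorProto into a migraphx literal; the data comes
    // either from raw_data or from the typed data fields.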
    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

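    // Map an onnx tensor type to a migraphx shape; dynamic or non-positive
    // dimensions are replaced with the user-provided batch_size.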
    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

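    // Map the integer onnx TensorProto::DataType values to migraphx shape
    // types.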
    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

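// Shared driver for the parse_onnx* entry points below; in debug builds the
// partially parsed program is printed when parsing fails.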
template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}

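// A minimal usage sketch (hypothetical caller code, not part of this file),
// assuming a model file at the placeholder path "model.onnx":
//
//   migraphx::onnx_options options;
//   options.batch_size = 4;                  // substituted for dynamic dims
//   migraphx::program p = migraphx::parse_onnx("model.onnx", options);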
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx