#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

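    // The parser is a dispatch table: each ONNX operator name maps to an
    // op_func callback that lowers one NodeProto into MIGraphX instructions.
    // add_generic_op registers ops that map 1:1 onto a MIGraphX operation,
    // add_binary_op and add_variadic_op add broadcasting on top, and
    // add_mem_op binds a parse_* member function for operators that need
    // custom attribute handling.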
    onnx_parser()
    {
        // ONNX operators are registered in alphabetical order by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names formatted in all lower case or with the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

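    // Computes the result dimensions of NumPy-style broadcasting: shapes are
    // right-aligned, and each pair of dimensions must either match or one of
    // them must be 1. For example, (3,2,4,5) and (2,1,1) right-align as
    //   3 2 4 5
    //     2 1 1
    // and broadcast to (3,2,4,5); a pair like (3,4) vs (2,4) throws.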
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast the (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1,
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1, giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

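    // ONNX "pads" for 2-D ops are laid out as [top, left, bottom, right].
    // A symmetric pad can be folded into the operator's own padding field;
    // an asymmetric pad has to be materialized as an explicit op::pad
    // instruction in front of the operator.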
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

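    // Clip: in opset 11 min/max arrive as optional inputs (args 2 and 3),
    // while older opsets pass them as the "min"/"max" attributes; both
    // forms are handled below.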
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using a previous opset, min/max are attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

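    // auto_pad SAME_UPPER/SAME_LOWER asks for output spatial dims of
    // ceil(input / stride); calculate_padding splits the required total
    // padding between begin and end, and any leftover asymmetry is handled
    // through check_asym_padding above.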
    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = weights->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

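    // ConvTranspose: asymmetric pads cannot be folded into the deconvolution,
    // so they are emulated by slicing the output; "output_padding" and
    // "output_shape" are realized as trailing op::pad instructions.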
    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

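    // Reshape: opset 1 carries the target dims in the "shape" attribute
    // (one input), opset 5+ passes them as a second input, which must be a
    // compile-time constant here.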
    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // change to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

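    // Slice: opset 10+ passes the parameters as inputs in the order
    // (data, starts, ends, axes, steps), older opsets use attributes;
    // each input must evaluate to a constant, and only step == 1 is supported.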
    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs; we first check the 5th one (steps)
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

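    // Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() is an
    // optional transpose controlled by transA/transB; C is multibroadcast to
    // the output shape when its dimensions do not already match.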
    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

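    // MatMul follows NumPy semantics: a 1-D A is unsqueezed to (1,k) and a
    // 1-D B to (k,1), the batch dimensions (all but the last two) are
    // broadcast against each other, and the temporary unit dimensions are
    // squeezed back out of the result.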
    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

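    // ImageScaler computes y = scale * x + bias, with one bias value per
    // channel broadcast over axis 1.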
    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the Shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
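
    // Example (illustrative only): for an argument of lens {2, 3, 4}, Shape
    // becomes the literal {2, 3, 4} with shape {int64_type, {3}}, so
    // downstream shape-consuming nodes can be folded at parse time.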

    // Use a literal instruction to replace the ConstantFill operator. In RNN, input shape
    // and value are fixed, so there is no need to do the actual computation for the
    // ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
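
    // Example (illustrative only): with input_as_shape = 1, dtype = 1 (float),
    // and an input literal holding {2, 5}, ConstantFill(value = 1.0f) becomes
    // a literal of shape {float_type, {2, 5}} filled with ten 1.0f values.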

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
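
    // Note (illustrative only): the scalar case uses shape{type, {1}, {0}},
    // i.e. one element with stride 0, which is how a scalar literal is
    // encoded here.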

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
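
    // Example (illustrative only): Expand follows NumPy broadcasting via
    // compute_broadcasted_lens; in_lens {3, 1} with a target shape of
    // {2, 1, 4} gives out_lens {2, 3, 4}, realized as op::multibroadcast.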

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // a bidirectional case should have two activation functions:
        // one for forward, and the other for reverse.
        // if only one actv function is provided, use it for both
        // the forward and reverse directions
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators until there are 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
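
    // Example (illustrative only): a bidirectional RNN with
    // activations = {"tanh"} ends up with vec_names = {"tanh", "tanh"}; the
    // two returned instructions correspond to ONNX outputs Y (all hidden
    // states) and Y_h (the last hidden state).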

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // use it four times. If 2 actv functions are provided,
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
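
    // Example (illustrative only): a bidirectional GRU with
    // activations = {"sigmoid", "tanh"} is expanded to
    // {"sigmoid", "tanh", "sigmoid", "tanh"}, reusing the forward pair for
    // the reverse direction.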

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // repeat the 1st one six times. If 2 actv functions are provided,
            // repeat the 2nd once, then repeat all three once.
            // If 3 actv funcs are provided, repeat all three once.
            // If 4 or 5 are provided, repeat the last one to fill the
            // six slots. This may need to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
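
    // Example (illustrative only): a bidirectional LSTM with
    // activations = {"sigmoid", "tanh"} is expanded to
    // {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}; the three
    // returned instructions correspond to ONNX outputs Y, Y_h, and Y_c.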

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
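
    // Example (illustrative only): ReduceSum over axes = {1} on lens
    // {2, 3, 4} yields lens {2, 1, 4} when keepdims = 1; with keepdims = 0
    // the op::squeeze collapses the result to lens {2, 4}.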

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }
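
    // The wrappers above decompose composite reductions into existing ops,
    // e.g. (a sketch) ReduceL2(x) = sqrt(reduce_sum(x * x)) and
    // ReduceLogSumExp(x) = log(reduce_sum(exp(x))).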

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
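
    // Example (illustrative only): Cast with to = 1 (TensorProto::FLOAT)
    // lowers to op::convert{shape::float_type}; the dtype codes are mapped
    // by get_type below.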

    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
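
    // Example (illustrative only): Split on axis 0 of lens {6, 4} with
    // split = {2, 4} emits op::slice{{0}, {0}, {2}} and
    // op::slice{{0}, {2}, {6}}; with no split attribute and num_outputs = 3,
    // each of the three slices covers 2 rows.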

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                else if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                std::size_t output_num = static_cast<std::size_t>(node.output().size());
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }
            // Even nodes with no outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                auto output_num = std::min<std::size_t>(node.output().size(), result.size());
                std::transform(node.output().begin(),
                               node.output().begin() + output_num,
                               result.begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(x, y); });
            }
        }
    }
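
    // parse_node walks producers depth-first, so by the time a node is added
    // all of its inputs are already registered in `instructions`. Sketch:
    // for out = Add(x, Relu(y)), parsing "out" first recurses into the Relu
    // node, then emits the Add instruction.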

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    static std::vector<int64_t> get_indices(const onnx::AttributeProto& attr)
    {
        std::vector<int64_t> result;
        literal s = parse_value(attr);
        s.visit([&](auto v) { copy(v, std::back_inserter(result)); });
        // Clamp large indices to -1
        std::replace_if(
            result.begin(),
            result.end(),
            [](auto x) { return x > int64_t{std::numeric_limits<std::int32_t>::max()} / 2; },
            -1);
        return result;
    }
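
    // Rationale (an assumption based on common exporter behavior): Slice
    // ends are often written as huge sentinels such as INT64_MAX to mean
    // "to the end of the axis"; anything above INT32_MAX / 2 is normalized
    // to -1 here.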

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            // fp16 payloads are stored bit-packed in int32_data; reinterpret
            // the low 16 bits of each entry as a half
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           // dims without a value (or with a non-positive
                           // value) are treated as the batch dimension and
                           // replaced by the user-provided batch_size
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}
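
// Usage sketch (illustrative only; "model.onnx" is a hypothetical file):
//
//   migraphx::onnx_options options;
//   options.batch_size = 4; // substituted for dynamic/unknown dims
//   migraphx::program p = migraphx::parse_onnx("model.onnx", options);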

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx