#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>
#include <algorithm>
#include <numeric>
#include <limits>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;

    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog                  = program();
    bool is_pytorch               = false;
    std::size_t default_dim_value = 1;
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // Sort ONNX operators alphabetically by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names in all lower case or with the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

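    // Compute the output lengths of a NumPy-style multidirectional broadcast:
    // dimensions are aligned from the right, a dimension of 1 is stretched to
    // match the other operand, and mismatched non-1 dimensions are an error.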
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

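    // Add a binary operation, multibroadcasting either argument up to the common
    // output shape first when the two input shapes differ.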
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

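    // Convolution/pooling style ops only support symmetric padding directly; when the
    // ONNX pads are asymmetric, insert an explicit pad instruction in front of the op
    // instead of setting op.padding.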
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

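    // Clip: newer opsets pass min/max as optional inputs (args[1]/args[2]); older
    // opsets pass them as the "min"/"max" attributes. With only a min bound this
    // degrades to a max instruction, and with no bounds to identity.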
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

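    // Resolve the ONNX auto_pad attribute (SAME_UPPER/SAME_LOWER): compute the implied
    // padding for the spatial dims from the kernel, stride and dilation, then apply it
    // through check_asym_padding.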
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

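    // Conv / ConvInteger: Op selects op::convolution or op::quant_convolution. Handles
    // the pads/strides/dilations/group attributes, SAME auto-padding, and an optional
    // bias passed as a third input.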
    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

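    // ConvTranspose: maps to op::deconvolution. Asymmetric pads are applied by slicing
    // the output, and output_padding / output_shape are honoured by padding the result
    // up to the requested spatial size.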
    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }
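
    // MaxPool / AveragePool / GlobalMaxPool / GlobalAveragePool. Global variants pool
    // over the full spatial extent; max pooling pads with the lowest float so padded
    // elements never affect the result.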
    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

Paul's avatar
Paul committed
709
    instruction_ref
710
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
Paul's avatar
Paul committed
711
    {
712
        int64_t axis = 1;
713
        if(contains(info.attributes, "axis"))
Paul's avatar
Paul committed
714
        {
715
            axis = parse_value(info.attributes.at("axis")).at<int>();
Paul's avatar
Paul committed
716
        }
717
        return prog.add_instruction(op::flatten{axis}, args[0]);
Paul's avatar
Paul committed
718
719
    }

720
    instruction_ref
721
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
722
723
    {
        op::squeeze op;
724
        literal s = parse_value(info.attributes.at("axes"));
725
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
726
        return prog.add_instruction(op, make_contiguous(args[0]));
727
728
729
    }

    instruction_ref
730
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
731
732
    {
        op::unsqueeze op;
733
        literal s = parse_value(info.attributes.at("axes"));
734
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
735
        return prog.add_instruction(op, make_contiguous(args[0]));
736
737
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // TODO: change to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

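    // Slice: newer opsets pass starts/ends/axes/steps as inputs, earlier opsets as
    // attributes. Only unit steps and constant (compile-time evaluable) inputs are
    // supported; when axes are omitted they default to all dimensions.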
    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }
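
    // Gemm: Y = alpha * op(A) * op(B) + beta * C, where transA/transB transpose the
    // last two dimensions. C is multibroadcast to the output shape when needed and is
    // dropped entirely if beta == 0 or it is empty.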
    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

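    // MatMul / MatMulInteger (Op is op::dot or op::quant_dot). 1-D operands are
    // temporarily unsqueezed to matrices, batch dimensions are broadcast to a common
    // shape, and the added dimensions are squeezed away from the result.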
    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

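    // ImageScaler: out = scale * input + bias, with bias broadcast per channel
    // (axis 1); lowered to explicit mul/add instructions.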
    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

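    // Pad: newer opsets pass pads (and the optional constant value) as inputs, older
    // opsets as attributes; only constant-mode padding is supported, and an all-zero
    // pad collapses to identity.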
    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
            {
                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
            }
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
1161
1162
1163
    // Use a literal instruction to replace the shape since, output of
    // shape operator are literals in migraphx
    instruction_ref
1164
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
1165
1166
    {
        if(args.size() != 1)
1167
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
1180
1181
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
1182
1183
1184
1185
1186
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

1187
        if(contains(info.attributes, "dtype"))
1188
        {
1189
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
1190
        }
Shucai Xiao's avatar
Shucai Xiao committed
1191
        shape::type_t type = get_type(dtype);
1192

1193
        if(contains(info.attributes, "input_as_shape"))
1194
        {
1195
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
1196
1197
        }

1198
        if(contains(info.attributes, "value"))
1199
        {
1200
            value = parse_value(info.attributes.at("value")).at<float>();
1201
1202
        }

1203
        if(contains(info.attributes, "extra_shape"))
Shucai Xiao's avatar
Shucai Xiao committed
1204
        {
1205
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
1206
1207
        }

1208
1209
        if(input_as_shape == 1)
        {
Shucai Xiao's avatar
Shucai Xiao committed
1210
            if(args.size() != 1)
1211
            {
1212
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
1213
1214
            }

1215
            if(contains(info.attributes, "shape"))
Shucai Xiao's avatar
Shucai Xiao committed
1216
            {
1217
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
Shucai Xiao's avatar
Shucai Xiao committed
1218
                               "at the same time");
1219
1220
            }

1221
            migraphx::argument in = args[0]->eval();
Shucai Xiao's avatar
Shucai Xiao committed
1222
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");
1223

1224
1225
1226
            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
1227
1228
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
1229
1230
1231
        }
        else if(input_as_shape == 0)
        {
1232
            if(!contains(info.attributes, "shape"))
Shucai Xiao's avatar
Shucai Xiao committed
1233
            {
1234
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
1235
1236
            }

1237
            literal ls = parse_value(info.attributes.at("shape"));
1238
            std::vector<std::size_t> dims;
Shucai Xiao's avatar
Shucai Xiao committed
1239
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
1240
            migraphx::shape s{type, dims};
1241
1242
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
1243
1244
1245
        }
        else
        {
1246
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
1247
1248
1249
        }
    }

1250
1251
    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
1252
1253
    {
        literal l_val{};
1254
        if(contains(info.attributes, "value"))
1255
        {
1256
            l_val = parse_value(info.attributes.at("value"));
Shucai Xiao's avatar
Shucai Xiao committed
1257
            if(l_val.get_shape().elements() != 1)
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // input is empty, output is a scalar
        auto type = l_val.get_shape().type();
1269

Shucai Xiao's avatar
Shucai Xiao committed
1270
        if(args.empty())
1271
        {
Shucai Xiao's avatar
Shucai Xiao committed
1272
            MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
1273
1274
1275
        }
        else
        {
1276
1277
            migraphx::shape s;
            // empty input tensor, output is a scalar
Shucai Xiao's avatar
Shucai Xiao committed
1278
            if(args[0]->get_shape().elements() == 0)
1279
            {
1280
                s = migraphx::shape{type, {1}, {0}};
1281
            }
1282
1283
1284
            else
            {
                migraphx::argument in = args[0]->eval();
Shucai Xiao's avatar
Shucai Xiao committed
1285
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");
1286

1287
1288
1289
1290
                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }
1291

Shucai Xiao's avatar
Shucai Xiao committed
1292
            literal l_out{};
1293
            l_val.visit([&](auto val) {
Shucai Xiao's avatar
Shucai Xiao committed
1294
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
1295
                // l_val contains only one element
1296
                std::vector<val_type> out_vec(s.elements(), val.front());
1297
1298
1299
1300
1301
1302
1303
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }

Shucai Xiao's avatar
Shucai Xiao committed
1304
    instruction_ref
1305
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
1306
    {
Shucai Xiao's avatar
Shucai Xiao committed
1307
        auto in_lens             = args[0]->get_shape().lens();
1308
        migraphx::argument arg_s = args[1]->eval();
Shucai Xiao's avatar
Shucai Xiao committed
1309
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
1310
1311
1312
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
Shucai Xiao's avatar
Shucai Xiao committed
1313
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
1314
1315
    }

Shucai Xiao's avatar
Shucai Xiao committed
1316
    std::vector<instruction_ref>
1317
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1318
1319
    {
        migraphx::shape input_shape = args[0]->get_shape();
1320
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
1321

1322
        if(contains(info.attributes, "hidden_size"))
Shucai Xiao's avatar
Shucai Xiao committed
1323
        {
1324
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1325
            if(hidden_size != hidden_size_att)
Shucai Xiao's avatar
Shucai Xiao committed
1326
1327
1328
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
Shucai Xiao's avatar
Shucai Xiao committed
1329
1330
1331
1332
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
1333
        if(contains(info.attributes, "direction"))
Shucai Xiao's avatar
Shucai Xiao committed
1334
        {
1335
            direction = info.attributes.at("direction").s();
Shucai Xiao's avatar
Shucai Xiao committed
1336
1337
        }

1338
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1339
1340
        if(direction == "bidirectional")
        {
1341
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1342
1343
1344
        }
        else if(direction == "reverse")
        {
1345
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1346
1347
        }

1348
        std::vector<std::string> vec_names{"tanh"};
1349
        if(contains(info.attributes, "activations"))
1350
        {
1351
            auto names = info.attributes.at("activations").strings();
1352
            vec_names.clear();
1353
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1354
1355
1356
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
1357
1358
        }

1359
1360
1361
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1362
        if(name_it != vec_names.end())
1363
1364
1365
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }
1366

Shucai Xiao's avatar
Shucai Xiao committed
1367
        // bidirectional case should have two activation functions.
Shucai Xiao's avatar
Shucai Xiao committed
1368
        // one is for forward, and the other is for reverse.
Shucai Xiao's avatar
Shucai Xiao committed
1369
        // if only one actv function is provided, we use it in both
1370
        // forward and reverse direction
1371
        if(dirct == op::rnn_direction::bidirectional)
1372
        {
Shucai Xiao's avatar
Shucai Xiao committed
1373
            if(vec_names.size() == 1)
1374
1375
1376
1377
1378
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
1379
        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1380
1381
1382
1383
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });
Shucai Xiao's avatar
Shucai Xiao committed
1384

Shucai Xiao's avatar
Shucai Xiao committed
1385
1386
        // To be added later
        float clip = 0.0;
1387
        if(contains(info.attributes, "clip"))
Shucai Xiao's avatar
Shucai Xiao committed
1388
        {
1389
            clip = parse_value(info.attributes.at("clip")).at<float>();
Shucai Xiao's avatar
Shucai Xiao committed
1390
1391
        }

1392
1393
        // if the number of arguments is less than 6, append
        // undefined operator to have 6 arguments
Shucai Xiao's avatar
Shucai Xiao committed
1394
        if(args.size() < 6)
1395
1396
1397
1398
1399
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

Shucai Xiao's avatar
Shucai Xiao committed
1400
1401
        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
Shucai Xiao's avatar
Shucai Xiao committed
1402
                                                  std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1403

1404
        // second output for the last hidden state
Shucai Xiao's avatar
Shucai Xiao committed
1405
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
Shucai Xiao's avatar
Shucai Xiao committed
1406

Shucai Xiao's avatar
Shucai Xiao committed
1407
        return {hidden_states, last_output};
Shucai Xiao's avatar
Shucai Xiao committed
1408
1409
    }

1410
    std::vector<instruction_ref>
1411
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
1412
1413
1414
1415
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

1416
        if(contains(info.attributes, "hidden_size"))
1417
        {
1418
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1419
            if(hidden_size != hidden_size_att)
Shucai Xiao's avatar
Shucai Xiao committed
1420
1421
1422
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
1423
1424
1425
1426
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
1427
        if(contains(info.attributes, "direction"))
1428
        {
1429
            direction = info.attributes.at("direction").s();
1430
1431
        }

1432
        op::rnn_direction dirct = op::rnn_direction::forward;
1433
1434
        if(direction == "bidirectional")
        {
1435
            dirct = op::rnn_direction::bidirectional;
1436
1437
1438
        }
        else if(direction == "reverse")
        {
1439
            dirct = op::rnn_direction::reverse;
1440
1441
        }

1442
        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
1443
        if(contains(info.attributes, "activations"))
1444
        {
1445
            auto names = info.attributes.at("activations").strings();
1446
            vec_names.clear();
Shucai Xiao's avatar
Shucai Xiao committed
1447
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1448
1449
1450
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
1451
1452
        }

1453
        // need 4 activation functions
1454
        if(dirct == op::rnn_direction::bidirectional)
1455
        {
Shucai Xiao's avatar
Shucai Xiao committed
1456
            // 4 activation functions are used in the bidirectional
1457
            // scenario. No spec is provided in onnx::operator. we
Shucai Xiao's avatar
Shucai Xiao committed
1458
1459
            // use the algorithm that: if 1 actv function is provided,
            // repeat 1 four times. If 2 actv functins are provided,
1460
1461
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
Shucai Xiao's avatar
Shucai Xiao committed
1462
1463
1464
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need change later
1465
            if(vec_names.size() == 1)
1466
            {
1467
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
1468
            }
1469
            else if(vec_names.size() == 2)
1470
            {
1471
1472
1473
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
1474
            }
1475
            else if(vec_names.size() == 3)
1476
            {
1477
                vec_names.push_back(vec_names.at(2));
1478
1479
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1480
        else
1481
        {
1482
            if(vec_names.size() == 1)
1483
            {
1484
                vec_names.push_back(vec_names.at(0));
1485
1486
1487
            }
        }

1488
1489
1490
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1491
        if(name_it != vec_names.end())
1492
1493
1494
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }
1495

Shucai Xiao's avatar
Shucai Xiao committed
1496
        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1497
1498
1499
1500
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });
1501
1502

        float clip = 0.0;
1503
        if(contains(info.attributes, "clip"))
1504
        {
1505
            clip = parse_value(info.attributes.at("clip")).at<float>();
1506
1507
1508
        }

        int linear_before_reset = 0;
1509
        if(contains(info.attributes, "linear_before_reset"))
1510
        {
1511
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
1512
1513
        }

Shucai Xiao's avatar
Shucai Xiao committed
1514
        // append undefined opeator to make 6 arguments
Shucai Xiao's avatar
Shucai Xiao committed
1515
        if(args.size() < 6)
Shucai Xiao's avatar
Shucai Xiao committed
1516
1517
1518
1519
1520
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

1521
1522
        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
1523
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
Shucai Xiao's avatar
Shucai Xiao committed
1524
            std::move(args));
1525
1526

        // second output for last gru output
1527
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
1528

Shucai Xiao's avatar
Shucai Xiao committed
1529
        return {hidden_states, last_output};
1530
1531
    }

Shucai Xiao's avatar
Shucai Xiao committed
1532
    std::vector<instruction_ref>
1533
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1534
1535
1536
1537
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

1538
        if(contains(info.attributes, "hidden_size"))
Shucai Xiao's avatar
Shucai Xiao committed
1539
        {
1540
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1541
1542
1543
1544
1545
1546
1547
1548
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
1549
        if(contains(info.attributes, "direction"))
Shucai Xiao's avatar
Shucai Xiao committed
1550
        {
1551
            direction = info.attributes.at("direction").s();
Shucai Xiao's avatar
Shucai Xiao committed
1552
1553
        }

Shucai Xiao's avatar
Shucai Xiao committed
1554
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1555
1556
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1557
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1558
1559
1560
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1561
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1562
        }
Shucai Xiao's avatar
Shucai Xiao committed
1563
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
1564
        {
Shucai Xiao's avatar
Shucai Xiao committed
1565
            dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1566
1567
1568
1569
1570
1571
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

1572
        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
1573
        if(contains(info.attributes, "activations"))
Shucai Xiao's avatar
Shucai Xiao committed
1574
        {
1575
            auto names = info.attributes.at("activations").strings();
Shucai Xiao's avatar
Shucai Xiao committed
1576
1577
            vec_names.clear();
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1578
1579
1580
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
Shucai Xiao's avatar
Shucai Xiao committed
1581
1582
1583
        }

        // need 6 activation functions for bidirectional directions
Shucai Xiao's avatar
Shucai Xiao committed
1584
        if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1585
1586
1587
1588
1589
1590
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator. we
            // use the algorithm that: if 1 actv function is provided,
            // repeat 1st six times. If 2 actv functins are provided,
            // repeat 2nd once, then repeat all three once
Shucai Xiao's avatar
Shucai Xiao committed
1591
            // if 3 actv funcs are provide, repeat all three once.
Shucai Xiao's avatar
Shucai Xiao committed
1592
1593
1594
1595
            // the same algorithm is used for 4, 5, and 6 actv funcions
            // provided. This may need change later
            switch(vec_names.size())
            {
1596
            case 1:
Shucai Xiao's avatar
Shucai Xiao committed
1597
1598
1599
1600
1601
1602
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
1603
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1604
1605
1606

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
Shucai Xiao's avatar
Shucai Xiao committed
1607
1608
1609
1610
1611
1612
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
Shucai Xiao's avatar
Shucai Xiao committed
1613
1614
1615
1616
                break;

            case 3:
                // repeat all three actv funcs once
Shucai Xiao's avatar
Shucai Xiao committed
1617
1618
1619
1620
1621
1622
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
Shucai Xiao's avatar
Shucai Xiao committed
1623
1624
                break;

Shucai Xiao's avatar
Shucai Xiao committed
1625
1626
1627
1628
1629
1630
1631
            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
1632
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1633

Shucai Xiao's avatar
Shucai Xiao committed
1634
1635
1636
1637
1638
1639
1640
            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
1641
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1642

Shucai Xiao's avatar
Shucai Xiao committed
1643
            default: break;
Shucai Xiao's avatar
Shucai Xiao committed
1644
1645
1646
1647
1648
1649
            }
        }
        else
        {
            switch(vec_names.size())
            {
Shucai Xiao's avatar
Shucai Xiao committed
1650
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
Shucai Xiao's avatar
Shucai Xiao committed
1651
1652
1653

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
Shucai Xiao's avatar
Shucai Xiao committed
1654
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
Shucai Xiao's avatar
Shucai Xiao committed
1655
1656
                break;

Shucai Xiao's avatar
Shucai Xiao committed
1657
            default: break;
Shucai Xiao's avatar
Shucai Xiao committed
1658
1659
1660
            }
        }

1661
1662
1663
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1664
        if(name_it != vec_names.end())
1665
1666
1667
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }
Shucai Xiao's avatar
Shucai Xiao committed
1668
1669

        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1670
1671
1672
1673
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });
Shucai Xiao's avatar
Shucai Xiao committed
1674
1675

        float clip = 0.0;
1676
        if(contains(info.attributes, "clip"))
Shucai Xiao's avatar
Shucai Xiao committed
1677
        {
1678
            clip = parse_value(info.attributes.at("clip")).at<float>();
Shucai Xiao's avatar
Shucai Xiao committed
1679
1680
1681
        }

        int input_forget = 0;
1682
        if(contains(info.attributes, "input_forget"))
Shucai Xiao's avatar
Shucai Xiao committed
1683
        {
1684
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1685
1686
1687
1688
1689
1690
        }

        // append undefined opeator to make 6 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
1691
            args.insert(args.end(), 8 - args.size(), ins);
Shucai Xiao's avatar
Shucai Xiao committed
1692
1693
1694
1695
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
1696
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1697
1698

        // second output for last lstm output
Shucai Xiao's avatar
Shucai Xiao committed
1699
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
Shucai Xiao's avatar
Shucai Xiao committed
1700
1701
1702
1703
1704
1705

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
1706

Shucai Xiao's avatar
Shucai Xiao committed
1707
    template <class T>
1708
1709
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1710
1711
1712
1713
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
1714
        std::vector<int64_t> axes(n_dim);
Shucai Xiao's avatar
Shucai Xiao committed
1715
        std::iota(axes.begin(), axes.end(), 0);
1716
        if(contains(info.attributes, "axes"))
Shucai Xiao's avatar
Shucai Xiao committed
1717
1718
        {
            axes.clear();
1719
            auto&& attr_axes = info.attributes["axes"].ints();
1720
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
Shucai Xiao's avatar
Shucai Xiao committed
1721
1722
1723
        }

        int keep_dims = 1;
1724
        if(contains(info.attributes, "keepdims"))
Shucai Xiao's avatar
Shucai Xiao committed
1725
        {
1726
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1727
1728
1729
1730
        }

        if(keep_dims == 1)
        {
Shucai Xiao's avatar
Shucai Xiao committed
1731
            return prog.add_instruction(T{axes}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1732
1733
1734
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
1735
            auto ins = prog.add_instruction(T{axes}, std::move(args));
1736
            return prog.add_instruction(op::squeeze{axes}, ins);
1737
1738
        }
    }
1739

Shucai Xiao's avatar
Shucai Xiao committed
1740
    instruction_ref
1741
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1742
1743
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
1744
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
Shucai Xiao's avatar
Shucai Xiao committed
1745
1746
1747
    }

    instruction_ref
1748
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1749
1750
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
1751
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
Shucai Xiao's avatar
Shucai Xiao committed
1752
1753
1754
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

1755
1756
    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1757
    {
1758
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1759
1760
1761
        return prog.add_instruction(op::log{}, sum_ins);
    }

1762
1763
    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1764
1765
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
1766
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
Shucai Xiao's avatar
Shucai Xiao committed
1767
1768
1769
        return prog.add_instruction(op::log{}, sum_ins);
    }

1770
1771
    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1772
1773
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
1774
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
Shucai Xiao's avatar
Shucai Xiao committed
1775
1776
    }

Shucai Xiao's avatar
Shucai Xiao committed
1777
    instruction_ref
1778
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
1779
    {
1780
        if(!contains(info.attributes, "to"))
1781
1782
1783
1784
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

1785
        int to_type        = parse_value(info.attributes.at("to")).at<int>();
1786
1787
1788
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
Shucai Xiao's avatar
Shucai Xiao committed
1789

1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }

kahmed10's avatar
kahmed10 committed
1843
1844
1845
1846
    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
Shucai Xiao's avatar
Shucai Xiao committed
1847
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
kahmed10's avatar
kahmed10 committed
1848
1849
1850
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
Shucai Xiao's avatar
Shucai Xiao committed
1851
1852
1853
1854
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }
kahmed10's avatar
kahmed10 committed
1855

Shucai Xiao's avatar
Shucai Xiao committed
1856
        std::vector<float> depth_input(depth * depth, 0.0f);
kahmed10's avatar
kahmed10 committed
1857
1858
        for(int i = 0; i < depth; i++)
        {
Shucai Xiao's avatar
Shucai Xiao committed
1859
            depth_input[depth * i + i] = 1.0f;
kahmed10's avatar
kahmed10 committed
1860
1861
        }

Shucai Xiao's avatar
Shucai Xiao committed
1862
1863
1864
1865
1866
1867
1868
1869
        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = prog.add_literal({s, depth_input});
        auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});

        // Finally, we need a transpose to move the inner most dim to the axis dim
        int n_rank = gather_out->get_shape().lens().size();
        if(axis < -n_rank or axis >= n_rank)
kahmed10's avatar
kahmed10 committed
1870
        {
Shucai Xiao's avatar
Shucai Xiao committed
1871
            MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
kahmed10's avatar
kahmed10 committed
1872
        }
Shucai Xiao's avatar
Shucai Xiao committed
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);
        auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
        auto lens   = tr_out->get_shape().lens();

        auto off_val       = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
        auto on_val        = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
        auto diff          = prog.add_instruction(op::sub{}, on_val, off_val);
        auto unsq_off_val  = prog.add_instruction(op::multibroadcast{lens}, off_val);
        auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
        auto l_mul         = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
        return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
kahmed10's avatar
kahmed10 committed
1887
1888
    }

kahmed10's avatar
kahmed10 committed
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
    instruction_ref
    parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        for(int i = 0; i < repeats.size(); i++)
        {
            auto l1 = l0;
            for(int j = 1; j < repeats[i]; j++)
            {
                l0 = prog.add_instruction(op::concat{i}, l0, l1);
            }
        }
        return l0;
    }

Paul's avatar
Paul committed
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
Paul's avatar
Paul committed
1921
            MIGRAPHX_THROW("Failed reading onnx file.");
Paul's avatar
Paul committed
1922
1923
1924
        }
    }

Paul Fultz II's avatar
Paul Fultz II committed
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

Paul's avatar
Paul committed
1941
1942
    void parse_graph(const onnx::GraphProto& graph)
    {
1943
        for(auto&& f : graph.initializer())
1944
1945
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

Paul's avatar
Paul committed
1946
1947
1948
        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
1949
1950
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
1951
            {
1952
1953
1954
1955
1956
1957
1958
                std::vector<std::size_t> dims;
                if(map_input_dims.count(name) > 0)
                {
                    dims = map_input_dims.at(name);
                }

                shape s            = parse_type(input.type(), dims);
1959
1960
                instructions[name] = prog.add_parameter(name, s);
            }
Paul's avatar
Paul committed
1961
        }
1962
1963

        for(auto&& node : graph.node())
Paul's avatar
Paul committed
1964
        {
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
Paul's avatar
Paul committed
1997
        }
Shucai Xiao's avatar
Shucai Xiao committed
1998

1999
        // Find instructions corresponding to the output
Shucai Xiao's avatar
Shucai Xiao committed
2000
        auto prog_output = graph.output();
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instuction
        prog.add_return(output_ins);
Paul's avatar
Paul committed
2021
2022
    }

Shucai Xiao's avatar
Shucai Xiao committed
2023
    void parse_undefined(const std::string& name)
2024
    {
Shucai Xiao's avatar
Shucai Xiao committed
2025
        auto ins           = prog.add_instruction(op::undefined{});
2026
2027
2028
        instructions[name] = ins;
    }

Paul's avatar
Paul committed
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
Paul's avatar
Paul committed
2053
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
Paul's avatar
Paul committed
2054
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
Paul's avatar
Paul committed
2055
2056
2057
2058
2059
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
2060
2061
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
Paul's avatar
Paul committed
2062
2063
        case onnx::AttributeProto::GRAPHS: return {};
        }
Paul's avatar
Paul committed
2064
        MIGRAPHX_THROW("Invalid attribute type");
Paul's avatar
Paul committed
2065
2066
2067
2068
2069
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
2070
2071
        if(t.has_raw_data())
        {
wsttiger's avatar
wsttiger committed
2072
            const std::string& s = t.raw_data();
Scott Thornton's avatar
Scott Thornton committed
2073
2074
            switch(t.data_type())
            {
2075
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
Khalique's avatar
Khalique committed
2076
2077
2078
2079
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
2080
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
Paul's avatar
Paul committed
2081
2082
2083
2084
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
2085
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
Paul's avatar
Paul committed
2086
2087
2088
2089
2090
2091
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
Scott Thornton's avatar
Scott Thornton committed
2092
2093
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
Paul's avatar
Paul committed
2094
            MIGRAPHX_THROW("Invalid tensor type");
2095
        }
Paul's avatar
Paul committed
2096
2097
2098
2099
2100
2101
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
Paul's avatar
Paul committed
2102
        case onnx::TensorProto::BOOL:
Khalique's avatar
Khalique committed
2103
            return create_literal(shape::int32_type, dims, t.int32_data());
Paul's avatar
Paul committed
2104
        case onnx::TensorProto::INT64:
Khalique's avatar
Khalique committed
2105
            return create_literal(shape::int64_type, dims, t.int64_data());
Paul's avatar
Paul committed
2106
2107
2108
2109
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
Paul's avatar
Paul committed
2110
        case onnx::TensorProto::FLOAT16:
Khalique's avatar
Khalique committed
2111
        {
Khalique's avatar
Khalique committed
2112
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
2113
            std::vector<half> data_half;
Khalique's avatar
Khalique committed
2114
2115
2116
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
2117
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
2118
            return create_literal(shape::half_type, dims, data_half);
Khalique's avatar
Khalique committed
2119
        }
Paul's avatar
Paul committed
2120
2121
2122
2123
2124
2125
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
Paul's avatar
Paul committed
2126
2127
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
Paul's avatar
Paul committed
2128
        MIGRAPHX_THROW("Invalid tensor type");
Paul's avatar
Paul committed
2129
2130
    }

Khalique's avatar
Khalique committed
2131
    static literal
2132
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
2133
    {
Khalique's avatar
Khalique committed
2134
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
2135
        if(dims.empty())
2136
            return literal{{shape_type}, data};
2137
2138
2139
        return literal{{shape_type, dims}, data};
    }

2140
    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
Khalique's avatar
Khalique committed
2141
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
2142
2143
    {
        if(dims.empty())
2144
            return literal{{shape_type}, data.begin(), data.end()};
2145
        return literal{{shape_type, dims}, data.begin(), data.end()};
2146
2147
    }

2148
    shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
Paul's avatar
Paul committed
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
Paul's avatar
Paul committed
2159
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
Paul's avatar
Paul committed
2160
2161
2162
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
Paul's avatar
Paul committed
2163
2164
2165
2166
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
Paul's avatar
Paul committed
2167
2168
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
Paul's avatar
Paul committed
2169
            break; // throw std::runtime_error("Unsupported type");
Paul's avatar
Paul committed
2170
        }
2171
2172
2173
2174
2175
2176

        if(!input_dims.empty())
        {
            return {shape_type, input_dims};
        }

Paul's avatar
Paul committed
2177
        std::vector<std::size_t> dims;
Paul's avatar
Paul committed
2178
        auto&& tensor_dims = t.tensor_type().shape().dim();
2179
2180
2181
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
2182
2183
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
2184
                           {
2185
                               if(static_cast<int>(d.dim_value()) <= 0)
2186
2187
2188
                               {
                                   return default_dim_value;
                               }
2189
                               return d.dim_value();
2190
                           }
2191
2192
2193
2194
                           else
                           {
                               return default_dim_value;
                           }
2195
                       });
2196

2197
2198
2199
        if(dims.empty())
            return {shape_type};

Paul's avatar
Paul committed
2200
2201
        return {shape_type, dims};
    }
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
Shucai Xiao's avatar
Shucai Xiao committed
2224
2225
2226

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
Shucai Xiao's avatar
Shucai Xiao committed
2227
        if(arg.empty())
Shucai Xiao's avatar
Shucai Xiao committed
2228
2229
2230
2231
        {
            MIGRAPHX_THROW(msg);
        }
    }
Paul's avatar
Paul committed
2232
2233
};

Paul Fultz II's avatar
Paul Fultz II committed
2234
template <class... Ts>
2235
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
Paul's avatar
Paul committed
2236
2237
{
    onnx_parser parser;
2238
2239
2240
    parser.map_input_dims    = options.map_input_dims;
    parser.default_dim_value = options.default_dim_value;

Paul's avatar
Paul committed
2241
2242
2243
2244
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
Paul Fultz II's avatar
Paul Fultz II committed
2245
        parser.parse_from(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
2246
2247
2248
2249
2250
2251
2252
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
Paul Fultz II's avatar
Paul Fultz II committed
2253
    parser.parse_from(std::forward<Ts>(xs)...);
Paul's avatar
Paul committed
2254
2255
2256
2257
#endif
    return std::move(parser.prog);
}

2258
program parse_onnx(const std::string& name, const onnx_options& options)
Paul Fultz II's avatar
Paul Fultz II committed
2259
2260
2261
2262
2263
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

2264
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
Paul Fultz II's avatar
Paul Fultz II committed
2265
2266
2267
2268
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

2269
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
Paul Fultz II's avatar
Paul Fultz II committed
2270
2271
2272
2273
{
    return parse_onnx_from(options, data, size);
}

Paul's avatar
Paul committed
2274
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
2275
} // namespace migraphx