#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // ONNX operators are registered in alphabetical order by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names that are all lower case or have the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

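    // Before ONNX opset 7, binary ops carried explicit "broadcast" and "axis"
    // attributes; newer opsets use numpy-style multidirectional broadcasting,
    // which is handled by the add_broadcastable_binary_op path below.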
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
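        //
        // General rule (numpy-style broadcasting): right-align the two shapes,
        // then every dimension pair must be equal or contain a 1, and the
        // output dimension is the larger of the two.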
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

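    // Ops such as reshape assume a standard (densely packed, row-major) input,
    // so transposed or broadcast views are first copied into a contiguous
    // buffer.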
    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

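    // When the input shapes differ, both arguments are multibroadcast to the
    // common shape computed by compute_broadcasted_lens before the elementwise
    // op is applied.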
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

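    // A third input, when present, is a bias vector: broadcast it along the
    // given axis (the channel axis for convolutions) and add it to the result.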
    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

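    // ONNX 2-D pads are given as [top, left, bottom, right]. Symmetric padding
    // can be folded into the op itself; asymmetric padding has to be
    // materialized as an explicit pad instruction on the spatial dimensions.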
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }

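    // Clip semantics differ across opsets: opset 11 passes min/max as optional
    // inputs (args 1 and 2), while earlier opsets pass them as attributes.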
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

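    // ArgMax/ArgMin: keepdims=1 (the default) keeps the reduced axis with size
    // 1; keepdims=0 is implemented by squeezing that axis after the reduction.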
    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

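    // auto_pad SAME_UPPER/SAME_LOWER pads so that the output spatial size is
    // ceil(input / stride); any odd total padding goes at the end (UPPER) or
    // the beginning (LOWER) of the dimension.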
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
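        // Asymmetric padding cannot be expressed on the deconvolution op
        // itself, so compute the unpadded output and slice the pad amounts off
        // the spatial dimensions afterwards.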
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
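        // Global pooling is ordinary pooling whose kernel spans the whole
        // spatial extent (H, W) of the input.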
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // changed to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

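    // Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() optionally
    // transposes the last two axes (transA/transB).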
    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

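    // MatMul follows numpy matmul semantics: 1-D inputs are unsqueezed to
    // matrices (a 1 prepended to A or appended to B), batch dimensions are
    // broadcast, and the temporary unit axes are squeezed away afterwards.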
    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5         = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

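    // LRN (per the ONNX spec): y = x / (bias + (alpha / size) * window_sum(x^2))^beta,
    // with the sum taken over a window of `size` adjacent channels.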
    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
Khalique's avatar
Khalique committed
1145
        {
1146
            value = parse_value(info.attributes.at("value")).at<float>();
Khalique's avatar
Khalique committed
1147
        }
1148

1149
        if(contains(info.attributes, "mode"))
Khalique's avatar
Khalique committed
1150
        {
1151
            auto mode = info.attributes.at("mode").s();
Khalique's avatar
Khalique committed
1152
            if(mode != "constant")
1153
1154
1155
            {
                MIGRAPHX_THROW("PARSE_PAD: migraphx currently only supports constant padding");
            }
Khalique's avatar
Khalique committed
1156
1157
1158
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
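
    // Illustrative note: ONNX "pads" uses the layout
    //     [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    // so for a rank-2 input of lens {2, 3}, pads={0, 1, 0, 1} adds one column
    // before and one after axis 1, giving lens {2, 5} filled with `value`.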

    // Replace the Shape operator with a literal instruction, since the output
    // of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
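
    // Illustrative note: for an input of lens {2, 3, 4}, the Shape node above
    // folds to an int64 literal {2, 3, 4} of lens {3}, which lets downstream
    // shape-consuming nodes be evaluated at parse time.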

    // Use a literal instruction to replace the ConstantFill operator. In RNN, the input
    // shape and value are fixed, so there is no need to do the actual computation for the
    // ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: invalid value for attribute input_as_shape");
        }
    }
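
    // Illustrative note: with input_as_shape=1 and a constant input holding
    // {2, 3}, the node above folds to a literal of lens {2, 3} filled with
    // `value`; with input_as_shape=0 the same dims come from the "shape"
    // attribute instead.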

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // the output type is taken from the value literal
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
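
    // Illustrative note: an input tensor with zero elements yields a scalar
    // literal (lens {1} with stride {0}); otherwise the input is evaluated and
    // its contents become the output dims, every element set to the single
    // value in the "value" attribute (0.0f by default).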

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
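
    // Illustrative note: expanding an input of lens {3, 1} with target dims
    // {2, 1, 6} follows numpy-style broadcasting, so compute_broadcasted_lens
    // yields {2, 3, 6} and the input is multibroadcast to that shape.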

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // parse the direction attribute; defaults to forward
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // the bidirectional case should have two activation functions:
        // one for forward and the other for reverse. If only one actv
        // function is provided, we use it in both the forward and
        // reverse directions
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // clip threshold; full clipping support to be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to have 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
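
    // Illustrative note: per the ONNX spec, for sequence length S, batch size
    // B and hidden size H, the first output above has lens
    // {S, num_directions, B, H}, and rnn_last_output extracts the hidden state
    // of the final time step as the second output.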

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // parse the direction attribute; defaults to forward
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is
            // provided, repeat it four times. If 2 actv functions are
            // provided, assume forward and reverse use the same pair.
            // If 3 actv functions are provided, assume the 3rd one is
            // repeated once and used by the reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
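
    // Illustrative note: the default activations {"sigmoid", "tanh"}
    // correspond to the ONNX GRU defaults for the gate activation f and the
    // hidden activation g; a bidirectional GRU needs the pair twice, once per
    // direction, which is what the fill-in logic above arranges.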

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // parse the direction attribute; defaults to forward
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is
            // provided, repeat it six times. If 2 actv functions are
            // provided, repeat the 2nd once, then repeat all three once.
            // If 3 actv functions are provided, repeat all three once.
            // The same algorithm is used when 4, 5, or 6 actv functions
            // are provided. This may need to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
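
    // Illustrative summary of the bidirectional activation fill-in above:
    //     {f}             -> {f, f, f, f, f, f}
    //     {f, g}          -> {f, g, g, f, g, g}
    //     {f, g, h}       -> {f, g, h, f, g, h}
    //     {f, g, h, i}    -> {f, g, h, i, i, i}
    //     {f, g, h, i, j} -> {f, g, h, i, j, j}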

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
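
    // Illustrative note: ReduceSum with axes={1} on an input of lens {2, 3, 4}
    // yields lens {2, 1, 4} when keepdims=1 (the default); with keepdims=0 the
    // squeeze above removes the reduced axis, giving lens {2, 4}.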

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
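
    // Illustrative note: a Cast node with to=10 (TensorProto::FLOAT16) maps
    // through get_type below to shape::half_type and lowers to op::convert.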

    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
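
    // Illustrative note: Split on axis=0 of a {6, 4} input with split={2, 4}
    // lowers to two slices, rows [0, 2) and [2, 6); with no split attribute
    // and 3 outputs, each slice would get 2 rows.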

    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        std::vector<float> on_off_vals;

        migraphx::argument values_arg = args[2]->eval();
        check_arg_empty(values_arg, "ONEHOT: values - dynamic shape not supported");
        values_arg.visit([&](auto v) { copy(v, std::back_inserter(on_off_vals)); });
        float off_value = on_off_vals[0];
        float on_value  = on_off_vals[1];

        std::vector<float> depth_input(depth * depth, off_value);
        for(std::size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = prog.add_literal({s, depth_input});
            return prog.add_instruction(op::gather{0}, {l0, args[0]});
        }
        MIGRAPHX_THROW("ONEHOT: MIGraphX does not support axis != -1");
    }
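
    // Illustrative note: with depth=3, off_value=0 and on_value=1, the literal
    // above is the 3x3 table {{1,0,0},{0,1,0},{0,0,1}}, and gathering row i
    // along axis 0 produces the one-hot encoding of index i.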

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("Unsupported tensor type");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("Unsupported tensor type");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // scalar constants in the onnx file have empty dims; use a scalar shape
        // to fill the initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }
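
    // Illustrative note: a graph input declared with an unknown or
    // non-positive leading dimension, e.g. float[?, 3, 224, 224], is parsed
    // above as lens {batch_size, 3, 224, 224}, since such dims are replaced by
    // the user-supplied batch_size.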

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}
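
// Illustrative usage sketch (assumes a "model.onnx" file on disk; not part of
// the original source):
//
//     migraphx::onnx_options options;
//     options.batch_size = 4;                        // substituted for dynamic dims
//     migraphx::program p = migraphx::parse_onnx("model.onnx", options);
//     std::cout << p << std::endl;                   // inspect the parsed program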

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx