#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // ONNX operators, sorted alphabetically by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Activation names may be all lower case or have the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }
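
    // Illustrative examples (not part of the original source), matching the
    // numpy-style rules implemented above:
    //   compute_broadcasted_lens({3, 2, 4, 5}, {2, 1, 1}); // -> {3, 2, 4, 5}
    //   compute_broadcasted_lens({3, 2, 1, 5}, {2, 7, 5}); // -> {3, 2, 7, 5}
    //   compute_broadcasted_lens({2, 3}, {3, 2});          // throws: 2 vs 3 mismatch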

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }
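
    // Illustrative usage sketch (assumed standalone context, not from the
    // original source): broadcasting an elementwise add across mismatched
    // shapes.
    //   auto a = prog.add_parameter("a", shape{shape::float_type, {2, 3, 4}});
    //   auto b = prog.add_parameter("b", shape{shape::float_type, {4}});
    //   add_broadcastable_binary_op(a, b, op::add{});
    //   // b is multibroadcast to {2, 3, 4}, then the add is applied.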

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }
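
    // Illustrative sketch (not part of the original source): for a conv
    // output of shape {1, 64, 32, 32} and a 1-D bias of 64 elements,
    // add_bias(args, curr_ins, 1) broadcasts the bias along axis 1 (the
    // channel axis) to {1, 64, 32, 32} and adds it elementwise.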

    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            float pad_val = 0)
    {
        if(padding[0] != padding[2] || padding[1] != padding[3])
        {
            ins = prog.add_instruction(
                op::pad{{0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]}, pad_val},
                ins);
        }
        else
        {
            op.padding[0] = padding[0];
            op.padding[1] = padding[1];
        }
    }
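
    // Illustrative sketch (not part of the original source): ONNX 2-D "pads"
    // arrive as {top, left, bottom, right}. For padding = {1, 1, 2, 2} the
    // values are asymmetric, so an explicit op::pad{{0, 0, 1, 1, 0, 0, 2, 2}}
    // instruction is inserted; for padding = {1, 1, 1, 1} the op's own
    // symmetric padding fields are set instead.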

    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }
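
    // Note (added for clarity, not from the original source): opset 11 moved
    // Clip's min/max bounds from attributes to optional inputs, which is why
    // both forms are handled above. Clip(x, min, max) lowers to op::clip,
    // while Clip(x, min) degenerates to an elementwise op::max against min.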

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               node_info info,
                                               Op& op,
                                               std::array<std::size_t, 2> k_lens,
                                               std::array<std::size_t, 2> dilation,
                                               const std::vector<std::size_t>& in_lens,
                                               float value = 0.0f)
    {
        if(!contains(info.attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            std::vector<int64_t> padding(in_lens.size());
            calculate_padding(
                0, padding, in_lens[2], op.stride[0], dilation[0], k_lens[0], is_same_upper);
            calculate_padding(
                1, padding, in_lens[3], op.stride[1], dilation[1], k_lens[1], is_same_upper);

            check_asym_padding(ins, padding, op, value);
        }

        return ins;
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_CONV: auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_CONV: padding should have 4 values");
            }
            check_asym_padding(l0, padding, op);
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                padding.resize(input_dims.size());
                calculate_padding(
                    0, padding, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(
                    1, padding, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                check_asym_padding(l0, padding, op);
            }

            auto in_lens                      = args[0]->get_shape().lens();
            auto weight_lens                  = args[1]->get_shape().lens();
            std::array<std::size_t, 2> k_lens = {weight_lens[2], weight_lens[3]};
            l0 = process_auto_pad_attribute(l0, info, op, k_lens, op.dilation, in_lens);
        }
        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asymm_padding = false;
        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                asymm_padding = true;
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "dilations"))
        {
            copy(info.attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape{dims[2], dims[3]};
        if(asymm_padding)
        {
            op::slice slice_op;
            slice_op.axes   = {0, 1, 2, 3};
            slice_op.starts = {0, 0, 0 + padding[0], 0 + padding[1]};
            slice_op.ends   = {
                dims[0], dims[1], curr_shape[0] - padding[2], curr_shape[1] - padding[3]};

            l1 = prog.add_instruction(slice_op, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            std::vector<int64_t> output_padding;
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            output_padding = {0, 0, 0, 0, 0, 0, output_padding[0], output_padding[1]};
            l1             = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            dims       = to_int64_vector(l1->get_shape().lens());
            curr_shape = {dims[2], dims[3]};
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding = {0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       0,
                                                       output_shape[0] - curr_shape[0],
                                                       output_shape[1] - curr_shape[1]};
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(info.attributes, "pads"))
        {
            if(contains(info.attributes, "auto_pad"))
            {
                auto s = info.attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            float pad_val = 0;
            if(op.mode == "max")
                pad_val = std::numeric_limits<float>::lowest();
            check_asym_padding(l0, padding, op, pad_val);
        }

        if(contains(info.attributes, "strides"))
        {
            copy(info.attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            copy(info.attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }

            auto in_lens = args[0]->get_shape().lens();
            float val    = 0.0f;
            // MaxPool
            if(op.mode == "max")
            {
                val = std::numeric_limits<float>::lowest();
            }

            l0 = process_auto_pad_attribute(l0, info, op, op.lengths, {1, 1}, in_lens, val);
        }

        return prog.add_instruction(op, l0);
    }
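
    // Illustrative sketch (not part of the original source): for
    // GlobalAveragePool on an input of shape {N, C, H, W} = {1, 3, 7, 7},
    // op.lengths becomes {7, 7}, so each feature map is reduced to a single
    // value and the output shape is {1, 3, 1, 1}.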

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // changed to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs; check the 5th one first
        // to decide whether MIGraphX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }
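
    // Illustrative sketch (not part of the original source): opset-10+ Slice
    // passes starts/ends/axes (and steps) as inputs, while older opsets use
    // attributes; both paths above populate the same op::slice. For a {3, 4}
    // input with starts = {0, 1}, ends = {2, 3}, axes = {0, 1}, the output
    // shape is {2, 2}.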

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }
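
    // Note (added for clarity, not from the original source): Gemm computes
    // alpha * op(A) * op(B) + beta * C, where op() optionally transposes the
    // last two axes. A C whose shape differs from the {M, N} result is
    // multibroadcast up to it before being folded into op::dot{alpha, beta}.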

    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }
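
    // Illustrative sketch (not part of the original source): MatMul follows
    // numpy semantics for 1-D operands. For a {4} x {4, 5} product, the
    // vector is unsqueezed to {1, 4}, the dot yields {1, 5}, and the leading
    // axis is squeezed away to return a {5} result.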

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }

    // Use a literal instruction to replace Shape, since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
1131
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
1132
1133
    {
        if(args.size() != 1)
1134
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }
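
    // For example, an input of shape {2, 3, 4} becomes the int64 literal
    // {2, 3, 4} of shape {3}, so downstream nodes that only consume the
    // shape can work with a compile-time constant.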

    // Use a literal instruction to replace the ConstantFill operator. In RNN, the input
    // shape and value are fixed, so there is no need to do the actual computation for
    // the ConstantFill operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }
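
    // For example (hypothetical values): an input literal holding {2, 3}
    // with a value attribute of 1.5f produces a {2, 3} literal filled with
    // 1.5f; an empty input instead yields a single-element scalar.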

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
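
    // Expand follows numpy-style broadcasting, e.g. in_lens = {3, 1} with a
    // target of {2, 3, 4} broadcasts to out_lens = {2, 3, 4}; only a
    // multibroadcast instruction is emitted, so no data is copied at parse
    // time.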

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // the bidirectional case should have two activation functions:
        // one for forward, and the other for reverse.
        // if only one actv function is provided, we use it in both
        // the forward and reverse directions
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output: the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
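
    // The two results correspond to the ONNX RNN outputs Y (the hidden state
    // concatenated over all time steps) and Y_h (the hidden state of the
    // last time step).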

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in the onnx operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // repeat it four times. If 2 actv functions are provided,
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need to change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output: the last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
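
    // As with parse_rnn, the two results map to the ONNX outputs Y and Y_h.
    // The per-direction activation pair is (f, g) in ONNX terms, defaulting
    // to sigmoid for the gates and tanh for the hidden update.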

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional direction
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in the onnx operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // repeat the 1st one six times. If 2 actv functions are provided,
            // repeat the 2nd once, then repeat all three once.
            // If 3 actv funcs are provided, repeat all three once.
            // The same algorithm is used when 4, 5, or 6 actv functions
            // are provided. This may need to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output: the last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output: the last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
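
    // The three results correspond to the ONNX LSTM outputs Y, Y_h, and Y_c
    // (all hidden states, the last hidden state, and the last cell state).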

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
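
    // For example, ReduceSum with axes = {1} and keepdims = 0 on a {2, 3, 4}
    // input first produces a {2, 1, 4} reduction and then squeezes axis 1 to
    // give a {2, 4} result.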

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }
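
    // The reductions above are composed from reduce_sum plus elementwise ops,
    // e.g. ReduceL2(x) == sqrt(ReduceSum(x * x)) and
    // ReduceLogSumExp(x) == log(ReduceSum(exp(x))).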

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }

    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
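
    // For example (hypothetical values): splitting a {4, 6} input on axis 1
    // with split = {2, 4} emits slice{axes={1}, starts={0}, ends={2}} and
    // slice{axes={1}, starts={2}, ends={6}}.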

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find the instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }
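
    // Note: graph.node() is assumed to be topologically sorted, as required
    // by the ONNX IR spec; an input that has not been defined yet therefore
    // indicates an invalid (unordered) model and triggers the throw above.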

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
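
    // The integer codes above are the onnx::TensorProto::DataType enum values
    // (e.g. 1 == FLOAT, 7 == INT64, 10 == FLOAT16); types without a migraphx
    // equivalent, such as STRING (8), fall through to the throw.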

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

template <class... Ts>
program parse_onnx_from(onnx_options options, Ts&&... xs)
{
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(std::forward<Ts>(xs)...);
#endif
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

program parse_onnx_buffer(const std::string& buffer, onnx_options options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, onnx_options options)
{
    return parse_onnx_from(options, data, size);
}
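
// Caller-side usage sketch (the model file name is hypothetical):
//   migraphx::onnx_options options;
//   options.batch_size = 4; // substituted for dynamic or zero batch dims
//   migraphx::program p = migraphx::parse_onnx("model.onnx", options);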

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx