#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

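    // A note on the registration helpers used below: add_generic_op maps an
    // ONNX node directly to a single migraphx operator, add_binary_op adds
    // the broadcasting ONNX requires for two-operand math, add_variadic_op
    // folds any number of inputs with a binary operator, and add_mem_op
    // dispatches to one of the parse_* member functions defined later.
    // A registered op is invoked roughly as (illustrative sketch, not a
    // call made here):
    //   auto outputs = ops.at("Relu")(attributes, {input_ins});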
    onnx_parser()
    {
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Floor", op::floor{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});
        add_binary_op("Pow", op::pow{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support name format of all lower case or the first letter capital
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

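    // add_op adapts a builder that returns a single instruction_ref to the
    // op_func signature, which always returns a vector of outputs.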
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

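    // Binary ops still honor the legacy "broadcast"/"axis" attribute pair
    // used by older ONNX opsets and exporters; without those attributes,
    // numpy-style multidirectional broadcasting is applied instead.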
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

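    // Ops such as reshape and gather assume a standard (densely packed)
    // layout, so a contiguous copy is inserted only when the input's shape
    // is not already standard.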
    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

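    // For example, adding shapes {3,2,4,5} and {2,1,1} multibroadcasts the
    // second operand to {3,2,4,5} before applying the binary operator.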
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

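    // Variadic ops (Sum, Max, Min) fold their inputs left to right with the
    // given binary operator, broadcasting at each step.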
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

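    // Clip takes its bounds in attribute form here; when a bound is absent,
    // the default carried by op::clip is kept (an assumption about
    // op::clip's defaults, which are not spelled out in this file).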
    instruction_ref parse_clip(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        op::clip op;
        if(contains(attributes, "max"))
        {
            op.max_val = parse_value(attributes.at("max")).at<float>();
        }
        if(contains(attributes, "min"))
        {
            op.min_val = parse_value(attributes.at("min")).at<float>();
        }
        return prog.add_instruction(op, std::move(args));
    }

    template <class Op>
    instruction_ref parse_softmax(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

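    // ArgMax/ArgMin: with keepdims == 0 the reduced axis is squeezed away,
    // e.g. an argmax over axis 1 of a {2,3,4} input yields {2,4} rather
    // than {2,1,4}.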
    template <class Op>
    instruction_ref parse_arg_op(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

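    // Conv: the ONNX "pads" attribute holds {top, left, bottom, right} for
    // 2-d inputs. Symmetric padding maps onto op::convolution directly;
    // asymmetric padding is expressed with an explicit op::pad in front.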
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
            if(contains(attributes, "auto_pad"))
            {
                auto s = attributes["auto_pad"].s();
                if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0      = prog.add_instruction(op::pad{padding}, l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        if(args.size() == 3)
        {
            uint64_t axis = 1;
            auto l1       = prog.add_instruction(op, l0, args[1]);
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, l0, args[1]);
    }

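    // MaxPool/AveragePool and their Global variants; Global pooling is a
    // regular pooling whose window spans the whole spatial extent (dims 2
    // and 3 of an NCHW input).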
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }
        if(contains(attributes, "pads"))
        {
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
                                          l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(s.find("SAME_UPPER") == std::string::npos)
            {
                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
            }
            op.padding_mode = op::padding_mode_t::same;
        }

        return prog.add_instruction(op, l0);
    }

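    // Reshape takes its target dims either from a "shape" attribute (one
    // argument) or from a second, constant-foldable input (two arguments).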
    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        // TODO: handle negative axis values
        if(!contains(attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        std::vector<size_t> dims = args[0]->get_shape().lens();
        size_t num_dims          = dims.size();
        if(contains(attributes, "axes"))
        {
            literal s = parse_value(attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
        else
        {
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
        }

        if(contains(attributes, "ends"))
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }
        if(contains(attributes, "starts"))
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v = parse_value(attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

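    // Gemm computes Y = alpha * op(A) * op(B) + beta * C, where transA and
    // transB select op() as a transpose of the last two axes; C is
    // multibroadcast to the output shape when its lens differ.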
    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        if(contains(attributes, "beta"))
        {
            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
            transa = parse_value(attributes.at("transA")).at<bool>();
        }
        if(contains(attributes, "transB"))
        {
            transb = parse_value(attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

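    // MatMul follows numpy.matmul semantics: a rank-1 left operand gets a 1
    // prepended and a rank-1 right operand a 1 appended (squeezed back out
    // of the result), while the leading "batch" dims are broadcast together.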
    instruction_ref
    parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

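    // BatchNormalization in inference form:
    // y = gamma * (x - mean) / sqrt(var + epsilon) + beta,
    // with per-channel statistics in spatial mode.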
    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(attributes, "epsilon"))
        {
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        }
        if(contains(attributes, "momentum"))
        {
            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "spatial"))
        {
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        if(contains(attributes, "beta"))
            beta = parse_value(attributes.at("beta")).at<float>();
        if(contains(attributes, "bias"))
            bias = parse_value(attributes.at("bias")).at<float>();
        if(contains(attributes, "size"))
            size = parse_value(attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

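    // ImageScaler computes out = scale * input + bias, with one bias value
    // per channel, broadcast along axis 1 of an NCHW input.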
    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(attributes, "scale"))
        {
            scale = parse_value(attributes.at("scale")).at<float>();
        }

        if(contains(attributes, "bias"))
        {
            auto&& bias_floats = attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_lens = args.front()->get_shape().lens();

        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

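    // Pad with all-zero pads degenerates to an identity; only "constant"
    // mode is supported.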
    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](int64_t i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the shape, since the output of
    // the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the ConstantFill operator. In RNN
    // models, the input shape and value are fixed, so there is no need to do
    // the actual computation for the ConstantFill operator
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
        {
            dtype = parse_value(attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
        }

        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }

        if(contains(attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

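    // ConstantOfShape: the single input holds the output dims, and the
    // optional one-element "value" attribute supplies the fill value
    // (float 0 by default).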
    instruction_ref parse_constant_of_shape(const std::string&,
                                            attribute_map attributes,
                                            std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(attributes, "value"))
        {
            l_val = parse_value(attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // input is empty, output is a scalar
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }

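    // Expand multibroadcasts the input to the shape given by its constant
    // second input, using the same broadcasting rules as above.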
    instruction_ref
    parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }

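    // RNN produces two results: the concatenated hidden states across all
    // time steps, and the hidden state of the last step.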
    std::vector<instruction_ref>
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case needs two activation functions, one for
        // the forward direction and one for the reverse direction. If only
        // one activation function is provided, use it for both directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to have 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

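    // GRU mirrors parse_rnn, but each direction needs two activation
    // functions (for the gates and for the hidden state); see the expansion
    // rules below for under-specified activation lists.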
1076
    std::vector<instruction_ref>
1077
1078
1079
1080
1081
1082
1083
    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
Shucai Xiao's avatar
Shucai Xiao committed
1084
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1085
            if(hidden_size != hidden_size_att)
Shucai Xiao's avatar
Shucai Xiao committed
1086
1087
1088
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
1089
1090
1091
1092
1093
1094
1095
1096
1097
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

1098
        op::rnn_direction dirct = op::rnn_direction::forward;
1099
1100
        if(direction == "bidirectional")
        {
1101
            dirct = op::rnn_direction::bidirectional;
1102
1103
1104
        }
        else if(direction == "reverse")
        {
1105
            dirct = op::rnn_direction::reverse;
1106
1107
        }

1108
        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
1109
1110
        if(contains(attributes, "activations"))
        {
1111
            auto names = attributes.at("activations").strings();
1112
            vec_names.clear();
Shucai Xiao's avatar
Shucai Xiao committed
1113
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
1116
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
1117
1118
        }

        // bidirectional GRU needs 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // case, but the ONNX operator spec does not say how fewer
            // than 4 should be expanded. The algorithm used here: if 1
            // activation function is provided, repeat it 4 times. If 2
            // are provided, assume forward and reverse use the same
            // pair. If 3 are provided, assume the 3rd one is repeated
            // once and used by the reverse direction.
            // This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }
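
        // Expansion examples (illustrative only):
        //   bidirectional, {"sigmoid"}         -> {"sigmoid", "sigmoid", "sigmoid", "sigmoid"}
        //   bidirectional, {"sigmoid", "tanh"} -> {"sigmoid", "tanh", "sigmoid", "tanh"}
        //   forward or reverse, {"sigmoid"}    -> {"sigmoid", "sigmoid"}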

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
        }
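        // (linear_before_reset != 0 applies the hidden gate's linear
        // transformation before multiplying by the reset gate output,
        // the cuDNN-style GRU variant described in the ONNX spec)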

        // append undefined operators so there are 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // parse the direction attribute (defaults to forward)
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // bidirectional LSTM needs 6 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // case, but the ONNX operator spec does not say how fewer
            // than 6 should be expanded. The algorithm used here: if 1
            // activation function is provided, repeat it six times. If
            // 2 are provided, repeat the 2nd once, then repeat all
            // three for the reverse direction. If 3 are provided,
            // repeat all three for the reverse direction. The same
            // padding scheme is applied when 4 or 5 activation
            // functions are provided. This may need to change later.
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }
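
        // Expansion examples (illustrative only):
        //   bidirectional, {"sigmoid"}              -> {"sigmoid"} repeated 6 times
        //   bidirectional, {"sigmoid", "tanh"}      -> {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}
        //   forward or reverse, {"sigmoid", "tanh"} -> {"sigmoid", "tanh", "tanh"}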

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(attributes, "input_forget"))
        {
            input_forget = parse_value(attributes.at("input_forget")).at<int>();
        }

        // append undefined operators so there are 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }
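        // Note: the 8 ONNX LSTM inputs are X, W, R, B, sequence_lens,
        // initial_h, initial_c and P (peepholes); omitted trailing optional
        // inputs are stood in for by the undefined placeholders above.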

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

    template <class T>
    instruction_ref parse_reduce_oper(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }
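
    // Shape examples for the reduce operators (illustrative only): reducing
    // a {2, 3, 4} input over axes = {1} yields {2, 1, 4} when keepdims == 1,
    // and the squeezed shape {2, 4} when keepdims == 0.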

    instruction_ref
    parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        if(!contains(attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in the initializers, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }

        // For now, the last output with a valid name is treated as the
        // program output, and an identity instruction is added at the
        // end of the program
        auto prog_output = graph.output();
        auto oit         = std::find_if(prog_output.rbegin(), prog_output.rend(), [](auto& node) {
            return !node.name().empty();
        });

        if(instructions.count(oit->name()) > 0)
        {
            prog.add_instruction(op::identity{}, instructions[oit->name()]);
        }
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                else if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()](get_attributes(node), args);
            }
            // Even nodes with no outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                assert(node.output().size() <= result.size());
                std::transform(node.output().begin(),
                               node.output().end(),
                               result.begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(x, y); });
            }
        }
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unnamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    static std::vector<int64_t> get_indices(const onnx::AttributeProto& attr)
    {
        std::vector<int64_t> result;
        literal s = parse_value(attr);
        s.visit([&](auto v) { copy(v, std::back_inserter(result)); });
        // Clamp large indices to -1
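        // (e.g. an ONNX slice "end" given as INT64_MAX to mean "to the end
        // of the axis"; values above INT32_MAX / 2 are assumed here to be
        // such sentinels)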
        std::replace_if(
            result.begin(),
            result.end(),
            [](auto x) { return x > int64_t{std::numeric_limits<std::int32_t>::max()} / 2; },
            -1);
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: MIGRAPHX_THROW("Invalid tensor type");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
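            // FLOAT16 tensors store their data in the int32_data field; the
            // low 16 bits of each element carry the raw half-precision bit
            // pattern (per the ONNX TensorProto convention)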
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: MIGRAPHX_THROW("Invalid tensor type");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // scalar constants in the onnx file have empty dims; create a
        // single-element literal in that case
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }
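
    // e.g. a scalar initializer arrives with empty dims and becomes
    // literal{{shape::float_type}, data}, while a {2, 3} tensor becomes
    // literal{{shape::float_type, {2, 3}}, data} (types illustrative)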

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                                   return batch_size;
                               return d.dim_value();
                           }
                           return batch_size;
                       });
        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
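
    // (the integer values in get_type correspond to onnx::TensorProto::DataType,
    // e.g. 1 == FLOAT, 7 == INT64, 10 == FLOAT16)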

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    return std::move(parser.prog);
}
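
// Example usage (a sketch; "model.onnx" is a hypothetical path):
//
//   migraphx::onnx_options options;
//   options.batch_size = 4;
//   migraphx::program prog = migraphx::parse_onnx("model.onnx", options);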

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx