onnx.cpp 66.8 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
Paul's avatar
Paul committed
9
#include <utility>
10
#include <vector>
Paul's avatar
Paul committed
11

Paul's avatar
Paul committed
12
13
14
15
16
17
#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
18
#include <migraphx/onnx.hpp>
Paul's avatar
Paul committed
19
20

namespace migraphx {
Paul's avatar
Paul committed
21
inline namespace MIGRAPHX_INLINE_NS {
Paul's avatar
Paul committed
22
23
24
25
26

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
Paul's avatar
Paul committed
27
28
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
Paul's avatar
Paul committed
29
30
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
31
32
33
    program prog            = program();
    bool is_pytorch         = false;
    unsigned int batch_size = 1;
Paul's avatar
Paul committed
34
35

    std::unordered_map<std::string, op_func> ops;
36
    std::unordered_map<std::string, operation> map_actv_funcs;
Paul's avatar
Paul committed
37
38
39

    // Populate the operator dispatch table (`ops`) mapping ONNX node op_type
    // strings to parser callbacks, then build the RNN activation-function map.
    onnx_parser()
    {
        // Unary elementwise ops that map 1:1 onto a migraphx operation.
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Floor", op::floor{});

        // Binary elementwise ops; these honor the legacy "broadcast"/"axis"
        // attributes as well as multidirectional broadcasting.
        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});
        add_binary_op("Pow", op::pow{});

        // Variadic ops, lowered to a left fold of the binary operation.
        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        // Ops requiring attribute/input inspection; dispatched through
        // member-function parsers via add_mem_op.
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        // All four pooling variants share one parser; it keys off the name.
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);

        // init the activation function map
        init_actv_func();
    }

    // Build the lookup table of activation operations used by the RNN-family
    // parsers. Keys are stored lower-case; this covers both the all-lower-case
    // spelling and the first-letter-capitalized spelling of the ONNX
    // activation names (callers are expected to normalize the case).
    void init_actv_func()
    {
        map_actv_funcs.emplace("tanh", op::tanh{});
        map_actv_funcs.emplace("relu", op::relu{});
        map_actv_funcs.emplace("sigmoid", op::sigmoid{});
        map_actv_funcs.emplace("leakyrelu", op::leaky_relu{});
        map_actv_funcs.emplace("elu", op::elu{});
    }

    template <class F>
    void add_op(std::string name, F f)
Paul's avatar
Paul committed
130
131
132
133
134
135
136
137
138
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
Paul's avatar
Paul committed
139
140
141
142
143
144
145
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
Paul's avatar
Paul committed
146
        add_op(name, [=](auto&&... xs) {
Paul's avatar
Paul committed
147
148
149
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }
Khalique's avatar
Khalique committed
150

151
    // Register a strictly-two-operand elementwise operator.
    //
    // Supports the legacy (opset-6 and earlier) explicit "broadcast"/"axis"
    // attribute pair: when broadcast != 0, the second operand is broadcast
    // along `axis` to the first operand's shape. When the attributes are
    // absent, multidirectional (NumPy-style) broadcasting is applied via
    // add_broadcastable_binary_op.
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    // Broadcast args[1] to args[0]'s shape starting at `axis`,
                    // then apply the binary op.
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                // broadcast attribute present but zero: shapes must already match.
                return prog.add_instruction(x, args);
            }
            else
            {
                // No legacy attributes: use multidirectional broadcasting.
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

Shucai Xiao's avatar
Shucai Xiao committed
176
177
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
178
179
180
181
182
183
184
185
186
187
188
189
190
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
Shucai Xiao's avatar
Shucai Xiao committed
191
        if(s0.size() > s1.size())
192
193
194
195
196
197
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
Shucai Xiao's avatar
Shucai Xiao committed
198
199
200
201
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
202
                       [&](auto a, auto b) {
Shucai Xiao's avatar
Shucai Xiao committed
203
                           if(a != b and a != 1 and b != 1)
204
                           {
Shucai Xiao's avatar
Shucai Xiao committed
205
206
207
208
209
210
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });
211
212
213
214

        return out_lens;
    }

Shucai Xiao's avatar
Shucai Xiao committed
215
216
    // Return `ins` unchanged when its shape is already standard (packed,
    // row-major); otherwise insert a contiguous op so downstream operators
    // that require packed data (e.g. reshape, gather) see a standard layout.
    instruction_ref make_contiguous(instruction_ref ins)
    {
        return ins->get_shape().standard() ? ins
                                           : prog.add_instruction(op::contiguous{}, ins);
    }

Khalique's avatar
Khalique committed
225
226
227
    // Apply binary operation `x` to two instructions, inserting
    // multibroadcast instructions as needed so both operands reach the
    // common broadcasted shape first.
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        // Fast path: identical shapes need no broadcasting.
        if(arg0->get_shape().lens() == arg1->get_shape().lens())
        {
            return prog.add_instruction(x, {arg0, arg1});
        }

        auto out_lens = compute_broadcasted_lens(arg0->get_shape().lens(),
                                                 arg1->get_shape().lens());

        // Broadcast an operand only when its shape differs from the target.
        auto broadcast_to = [&](instruction_ref arg) {
            if(arg->get_shape().lens() == out_lens)
                return arg;
            return prog.add_instruction(op::multibroadcast{out_lens}, arg);
        };

        // Sequence the two broadcasts explicitly to keep a deterministic
        // instruction insertion order (arg0 first, then arg1).
        auto b0 = broadcast_to(arg0);
        auto b1 = broadcast_to(arg1);
        return prog.add_instruction(x, b0, b1);
    }

Paul's avatar
Paul committed
251
    template <class T>
Paul's avatar
Paul committed
252
253
    void add_generic_op(std::string name, T x)
    {
Paul's avatar
Paul committed
254
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
Paul's avatar
Paul committed
255
256
257
258
            return prog.add_instruction(x, args);
        });
    }

Khalique's avatar
Khalique committed
259
    template <class T>
Khalique's avatar
Khalique committed
260
    void add_variadic_op(std::string name, T x)
Khalique's avatar
Khalique committed
261
    {
Paul's avatar
Paul committed
262
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
Khalique's avatar
Khalique committed
263
            return std::accumulate(std::next(args.begin()),
Khalique's avatar
Khalique committed
264
265
266
267
268
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
Khalique's avatar
Khalique committed
269
        });
Khalique's avatar
Khalique committed
270
271
    }

Khalique's avatar
Khalique committed
272
273
274
    // Parse the ONNX Clip operator (attribute form): clamp the input to
    // [min, max]. Missing attributes leave the corresponding op::clip
    // default limit in place.
    instruction_ref parse_clip(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        op::clip op;
        if(contains(attributes, "max"))
            op.max_val = parse_value(attributes.at("max")).at<float>();
        if(contains(attributes, "min"))
            op.min_val = parse_value(attributes.at("min")).at<float>();
        return prog.add_instruction(op, std::move(args));
    }

Shucai Xiao's avatar
Shucai Xiao committed
288
    template <class Op>
289
    instruction_ref parse_softmax(const std::string&,
Shucai Xiao's avatar
Shucai Xiao committed
290
291
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
Paul's avatar
Paul committed
292
    {
293
        int64_t axis = 1;
294
295
296
297
298
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

299
        return prog.add_instruction(Op{axis}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
300
301
    }

Shucai Xiao's avatar
Shucai Xiao committed
302
    template <class Op>
303
    instruction_ref parse_arg_op(const std::string&,
Shucai Xiao's avatar
Shucai Xiao committed
304
305
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
306
    {
307
        int64_t axis = 0;
308
309
        if(contains(attributes, "axis"))
        {
310
            axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
311
312
        }

Shucai Xiao's avatar
Shucai Xiao committed
313
        int keep_dims = 1;
Shucai Xiao's avatar
Shucai Xiao committed
314
        if(contains(attributes, "keepdims"))
Shucai Xiao's avatar
Shucai Xiao committed
315
316
317
318
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

Shucai Xiao's avatar
Shucai Xiao committed
319
        if(keep_dims == 0)
320
        {
321
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
322
            return prog.add_instruction(op::squeeze{{axis}}, ins);
323
324
325
        }
        else
        {
326
            return prog.add_instruction(Op{axis}, std::move(args));
327
        }
328
329
    }

330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
    // Apply the ONNX "auto_pad" attribute (SAME_UPPER / SAME_LOWER) to a
    // pooling-style operator `op` with a 4-D NCHW input.
    //
    // Computes the symmetric padding needed to keep the spatial output size
    // at ceil(input / stride). Any odd leftover padding cannot be expressed
    // by the op's symmetric `padding` field, so it is materialized as an
    // explicit pad instruction on one side: the trailing side for
    // SAME_UPPER, the leading side for SAME_LOWER.
    //
    // Returns `ins` unchanged when no auto_pad attribute is present;
    // otherwise returns `ins` possibly wrapped in a pad instruction, and
    // mutates op.padding / op.padding_mode in place.
    template <class Op>
    instruction_ref process_auto_pad_attribute(instruction_ref ins,
                                               attribute_map& attributes,
                                               Op& op,
                                               const std::vector<std::size_t>& in_lens)
    {
        if(!contains(attributes, "auto_pad"))
        {
            return ins;
        }

        auto auto_pad = attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            // calculate the padding
            // Target spatial output dims: ceil(in / stride) for H (index 2)
            // and W (index 3).
            std::array<std::size_t, 2> out_lens;
            out_lens[0] = (in_lens[2] + op.stride[0] - 1) / op.stride[0];
            out_lens[1] = (in_lens[3] + op.stride[1] - 1) / op.stride[1];

            // Total padding per spatial axis required to hit out_lens.
            std::array<std::size_t, 2> explicit_pads;
            explicit_pads[0] = (out_lens[0] - 1) * op.stride[0] + op.lengths[0] - in_lens[2];
            explicit_pads[1] = (out_lens[1] - 1) * op.stride[1] + op.lengths[1] - in_lens[3];
            // Split evenly into the op's symmetric padding; the remainder
            // (0 or 1 per axis) must be handled with an explicit pad op.
            op.padding[0]    = explicit_pads[0] / 2;
            op.padding[1]    = explicit_pads[1] / 2;
            explicit_pads[0] -= 2 * op.padding[0];
            explicit_pads[1] -= 2 * op.padding[1];
            // 8-element pad spec for a 4-D tensor: leading pads for each of
            // N,C,H,W then trailing pads for N,C,H,W.
            std::vector<std::int64_t> pads(8, 0);
            if(explicit_pads[0] != 0 or explicit_pads[1] != 0)
            {
                if(auto_pad == "SAME_UPPER")
                {
                    // Extra padding goes at the end of H/W.
                    pads[6] = explicit_pads[0];
                    pads[7] = explicit_pads[1];
                }
                else if(auto_pad == "SAME_LOWER")
                {
                    // Extra padding goes at the beginning of H/W.
                    pads[2] = explicit_pads[0];
                    pads[3] = explicit_pads[1];
                }

                // MaxPool
                // Pad with -inf so padded cells never win the max.
                if(op.mode == "max")
                {
                    ins = prog.add_instruction(op::pad{pads, std::numeric_limits<float>::lowest()},
                                               ins);
                }
                // AveragePool
                else
                {
                    ins = prog.add_instruction(op::pad{pads}, ins);
                }
            }

            op.padding_mode = op::padding_mode_t::same;
        }

        return ins;
    }

Paul's avatar
Paul committed
389
    // Parse the ONNX Conv operator into an op::convolution (plus optional
    // pad / broadcast-add instructions).
    //
    // Handles: explicit "pads" (asymmetric pads become a separate pad op,
    // since op.padding is symmetric), "strides", "dilations", "auto_pad"
    // (SAME* sets the op's padding mode), "group", and an optional third
    // input which is the bias, broadcast along the channel axis and added
    // to the convolution output.
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
            // Explicit pads and a non-NOTSET auto_pad are mutually exclusive.
            if(contains(attributes, "auto_pad"))
            {
                auto s = attributes["auto_pad"].s();
                if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
                }
            }
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            // Only 2-D spatial convolution is supported here:
            // pads = [top, left, bottom, right].
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // Asymmetric padding cannot be expressed by op.padding;
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0      = prog.add_instruction(op::pad{padding}, l0);
            }
            else
            {
                // Symmetric padding maps directly onto the convolution op.
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            // SAME_UPPER / SAME_LOWER both map to the op's "same" mode.
            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        // Third input is the bias: broadcast along the channel axis (1)
        // of the convolution output and add.
        if(args.size() == 3)
        {
            uint64_t axis = 1;
            auto l1       = prog.add_instruction(op, l0, args[1]);
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, l0, args[1]);
    }
Paul's avatar
Paul committed
456

Paul's avatar
Paul committed
457
458
459
    // Parse MaxPool / AveragePool / GlobalMaxPool / GlobalAveragePool.
    //
    // One parser serves all four operators; the node name selects the mode
    // ("max" vs "average") and the Global* variants use a window covering
    // the full spatial extent. Asymmetric "pads" become a separate pad op
    // (with -inf fill for max pooling so padding never wins); "auto_pad"
    // is delegated to process_auto_pad_attribute.
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        // Global pooling: the kernel spans the entire H x W of the input.
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }

        if(contains(attributes, "pads"))
        {
            // Explicit pads and a non-NOTSET auto_pad are mutually exclusive.
            if(contains(attributes, "auto_pad"))
            {
                auto s = attributes["auto_pad"].s();
                if(to_upper(s) != "NOTSET")
                {
                    MIGRAPHX_THROW(
                        "PARSE_POOLING: auto_pad and padding cannot be specified simultaneously");
                }
            }

            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            // Only 2-D pooling: pads = [top, left, bottom, right].
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("PARSE_POOLING: padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // Asymmetric padding cannot be expressed by op.padding;
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                // MaxPool
                // Pad with -inf so padded cells never win the max.
                if(op.mode == "max")
                {
                    l0 = prog.add_instruction(
                        op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
                }
                // AveragePool
                else
                {
                    l0 = prog.add_instruction(op::pad{padding}, l0);
                }
            }
            else
            {
                // Symmetric padding maps directly onto the pooling op.
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }

        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }

        // SAME_UPPER/SAME_LOWER handling (may wrap l0 in a pad instruction
        // and set op.padding / padding_mode).
        if(contains(attributes, "auto_pad"))
        {
            auto in_lens = args[0]->get_shape().lens();
            l0           = process_auto_pad_attribute(l0, attributes, op, in_lens);
        }

        return prog.add_instruction(op, l0);
    }

Paul's avatar
Paul committed
528
    // Parse the ONNX Reshape operator. The target dims come either from a
    // "shape" attribute (opset 1 form, single input) or from a constant
    // second input (opset 5+ form); a non-constant shape input is rejected.
    // The data input is made contiguous first since reshape requires a
    // packed layout.
    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        auto append_dims = [&](auto&& val) {
            val.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        };

        if(args.size() == 1)
        {
            append_dims(parse_value(attributes.at("shape")));
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            append_dims(s);
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

Paul's avatar
Paul committed
547
    // Parse the ONNX Flatten operator: collapse the input to 2-D around
    // `axis`, which defaults to 1 per the specification.
    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        const int64_t axis = contains(attributes, "axis")
                                 ? parse_value(attributes.at("axis")).at<int>()
                                 : 1;
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
    // Parse the ONNX Squeeze operator: remove the dimensions listed in the
    // mandatory "axes" attribute.
    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal axes_val = parse_value(attributes.at("axes"));
        axes_val.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    // Parse the ONNX Unsqueeze operator: insert size-1 dimensions at the
    // positions listed in the mandatory "axes" attribute.
    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal axes_val = parse_value(attributes.at("axes"));
        axes_val.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

Scott Thornton's avatar
Scott Thornton committed
576
577
578
    // Parse the ONNX Concat operator. The "axis" attribute is mandatory.
    // TODO: handle negative axis values.
    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        if(!contains(attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(attributes.at("axis")).at<int>();
        return prog.add_instruction(op::concat{axis}, std::move(args));
    }
589

590
591
592
    // Parse the ONNX Gather operator. `axis` defaults to 0. Both the data
    // and the indices are made contiguous since gather requires packed
    // inputs.
    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int axis = contains(attributes, "axis")
                       ? parse_value(attributes.at("axis")).at<int>()
                       : 0;
        return prog.add_instruction(
            op::gather{axis}, make_contiguous(args[0]), make_contiguous(args[1]));
    }

603
604
605
606
    // Parse the ONNX Slice operator (attribute form).
    //
    // "starts" and "ends" are required by the spec; "axes" is optional and
    // defaults to every dimension of the input in order.
    //
    // Fix: the original parsed "ends" through the get_indices helper but
    // hand-rolled the literal visiting for "axes" and "starts"; all three
    // attributes are the same int-list shape, so they now consistently use
    // get_indices. The unused local `dims` was also removed.
    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        size_t num_dims = args[0]->get_shape().lens().size();
        if(contains(attributes, "axes"))
        {
            op.axes = get_indices(attributes.at("axes"));
        }
        else
        {
            // Default: slice applies to all dimensions, in order.
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
        }

        if(contains(attributes, "ends"))
        {
            op.ends = get_indices(attributes.at("ends"));
        }
        if(contains(attributes, "starts"))
        {
            op.starts = get_indices(attributes.at("starts"));
        }
        return prog.add_instruction(op, args[0]);
    }

Paul's avatar
Paul committed
632
633
634
    // Parse the ONNX Constant operator: materialize the "value" attribute
    // as a program literal. An empty tensor becomes an empty literal, and a
    // tensor with no dimensions is rewrapped with a scalar shape.
    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v = parse_value(attributes.at("value"));

        // return empty literal
        if(v.get_shape().elements() == 0)
            return prog.add_literal(literal{});

        // if dim_size is 0, it is a scalar
        if(attributes.at("value").t().dims_size() == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }
Paul's avatar
Paul committed
653

Paul's avatar
Paul committed
654
    // Gemm: computes alpha * op(A) * op(B) [+ beta * C], where op() is an
    // optional transpose of the trailing two axes.
    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha  = 1.0f;
        float beta   = 1.0f;
        bool trans_a = false;
        bool trans_b = false;
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        if(contains(attributes, "beta"))
            beta = parse_value(attributes.at("beta")).at<float>();
        if(contains(attributes, "transA"))
            trans_a = parse_value(attributes.at("transA")).at<bool>();
        if(contains(attributes, "transB"))
            trans_b = parse_value(attributes.at("transB")).at<bool>();

        // Permutation that swaps the last two axes (matrix transpose).
        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);

        auto mat_a = trans_a ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto mat_b = trans_b ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];

        // The optional C input participates only when beta is nonzero and C is
        // nonempty; it is broadcast to the output shape when necessary.
        if(args.size() == 3 && beta != 0.f && args[2]->get_shape().elements() > 0)
        {
            auto out_lens   = mat_a->get_shape().lens();
            out_lens.back() = mat_b->get_shape().lens().back();
            auto mat_c      = args[2];
            auto c_lens     = mat_c->get_shape().lens();
            if(!std::equal(out_lens.begin(), out_lens.end(), c_lens.begin(), c_lens.end()))
            {
                mat_c = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
            }
            return prog.add_instruction(op::dot{alpha, beta}, mat_a, mat_b, mat_c);
        }

        return prog.add_instruction(op::dot{alpha, beta}, mat_a, mat_b);
    }

704
    // MatMul with numpy-style semantics: 1-D operands are temporarily promoted
    // to matrices, batch dimensions are broadcast, and the promotion axes are
    // squeezed away from the result.
    instruction_ref
    parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto a      = args[0];
        auto b      = args[1];
        auto a_lens = a->get_shape().lens();
        auto b_lens = b->get_shape().lens();

        // A vector on the left becomes a 1xK row matrix.
        bool squeeze_row = false;
        if(a_lens.size() == 1)
        {
            squeeze_row = true;
            a_lens.insert(a_lens.begin(), 1);
            a = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // A vector on the right becomes a Kx1 column matrix.
        bool squeeze_col = false;
        if(b_lens.size() == 1)
        {
            squeeze_col = true;
            b_lens.push_back(1);
            b = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bcast_a = a;
        instruction_ref bcast_b = b;
        // Broadcast the batch dimensions (everything but the last two) when
        // the two operands disagree on them.
        if(!std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2, b_lens.rend()))
        {
            auto a_mat_begin = a_lens.begin() + a_lens.size() - 2;
            std::vector<std::size_t> a_bcast_lens(a_lens.begin(), a_mat_begin);
            auto b_mat_begin = b_lens.begin() + b_lens.size() - 2;
            std::vector<std::size_t> b_bcast_lens(b_lens.begin(), b_mat_begin);
            auto batch_lens = compute_broadcasted_lens(a_bcast_lens, b_bcast_lens);
            a_bcast_lens    = batch_lens;
            a_bcast_lens.insert(a_bcast_lens.end(), a_mat_begin, a_lens.end());
            b_bcast_lens    = batch_lens;
            b_bcast_lens.insert(b_bcast_lens.end(), b_mat_begin, b_lens.end());
            if(a_lens != a_bcast_lens)
            {
                bcast_a = prog.add_instruction(op::multibroadcast{a_bcast_lens}, a);
            }
            if(b_lens != b_bcast_lens)
            {
                bcast_b = prog.add_instruction(op::multibroadcast{b_bcast_lens}, b);
            }
        }

        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bcast_a, bcast_b);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        // Remove the axes that were introduced for 1-D operands.
        if(squeeze_row)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(squeeze_col)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

767
    // BatchNormalization (inference form only).
    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon  = 1e-5f;
        float momentum = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(attributes, "epsilon"))
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        if(contains(attributes, "momentum"))
            momentum = parse_value(attributes.at("momentum")).at<float>();
        // A nonzero "spatial" attribute selects spatial mode, zero selects
        // per-activation mode.
        if(contains(attributes, "spatial"))
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        op::batch_norm_inference bn_op{epsilon, momentum, bn_mode};
        return prog.add_instruction(bn_op, std::move(args));
    }

791
792
793
794
    // LeakyRelu: f(x) = x for x >= 0, alpha * x otherwise.
    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        return prog.add_instruction(op::leaky_relu{alpha}, args.front());
    }

Khalique's avatar
Khalique committed
804
805
    // Elu: f(x) = x for x >= 0, alpha * (exp(x) - 1) otherwise.
    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        return prog.add_instruction(op::elu{alpha}, args.front());
    }

Khalique's avatar
Khalique committed
816
817
    // LRN: local response normalization.
    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        // Fetch a float attribute, falling back to the ONNX default.
        auto float_attr = [&](const std::string& key, float fallback) {
            return contains(attributes, key) ? parse_value(attributes.at(key)).at<float>()
                                             : fallback;
        };
        float alpha = float_attr("alpha", 0.0001);
        float beta  = float_attr("beta", 0.75);
        float bias  = float_attr("bias", 1.0);
        int size    = 1;
        if(contains(attributes, "size"))
            size = parse_value(attributes.at("size")).at<int>();
        return prog.add_instruction(op::lrn{alpha, beta, bias, size}, args.front());
    }

Khalique's avatar
Khalique committed
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
    // ImageScaler: out = input * scale + bias, with bias broadcast along the
    // channel axis (axis 1).
    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        if(contains(attributes, "scale"))
            scale = parse_value(attributes.at("scale")).at<float>();

        std::vector<float> bias{};
        if(contains(attributes, "bias"))
        {
            auto&& raw_bias = attributes["bias"].floats();
            bias.assign(raw_bias.begin(), raw_bias.end());
        }

        auto input_lens = args.front()->get_shape().lens();

        auto lit_scale = prog.add_literal(scale);
        auto lit_bias  = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        // Broadcast the scalar over the whole input, multiply, then add the
        // per-channel bias.
        auto bcast_scale = prog.add_instruction(migraphx::op::scalar{input_lens}, lit_scale);
        auto scaled      = prog.add_instruction(migraphx::op::mul{}, args.front(), bcast_scale);
        auto bcast_bias  = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, lit_bias);
        return prog.add_instruction(migraphx::op::add{}, scaled, bcast_bias);
    }
Khalique's avatar
Khalique committed
862

Khalique's avatar
Khalique committed
863
864
    // Transpose: permute axes according to the "perm" attribute (empty means
    // the operator's default reversal, handled by op::transpose itself).
    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm.assign(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

Khalique's avatar
Khalique committed
875
876
877
878
879
880
881
882
883
884
    // Pad: only constant-mode padding is supported; an all-zero pad vector is
    // lowered to an identity.
    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        // NOTE: compare as int64_t; the previous predicate took `const int&`,
        // silently narrowing each pad amount before the zero test.
        if(std::all_of(pads.begin(), pads.end(), [](int64_t i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
902
903
904
    // Use a literal instruction to replace the shape operator: the input's
    // dimensions are already known at parse time, so the output of Shape is a
    // compile-time int64 constant in migraphx.
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        const auto& lens = args[0]->get_shape().lens();
        std::vector<int64_t> dim_values;
        dim_values.reserve(lens.size());
        std::copy(lens.begin(), lens.end(), std::back_inserter(dim_values));
        migraphx::shape out_shape(migraphx::shape::int64_type, {lens.size()});
        return prog.add_literal(migraphx::literal{out_shape, dim_values});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN
    // models the output shape and fill value are fixed, so the fill can be
    // computed at parse time instead of at runtime.
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
            dtype = parse_value(attributes.at("dtype")).at<int>();
        shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();

        if(contains(attributes, "value"))
            value = parse_value(attributes.at("value")).at<float>();

        if(contains(attributes, "extra_shape"))
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");

        if(input_as_shape == 1)
        {
            // The single input supplies the output dimensions.
            if(args.size() != 1)
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");

            if(contains(attributes, "shape"))
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape out_shape(type, dims);
            std::vector<float> fill(out_shape.elements(), value);
            return prog.add_literal(migraphx::literal(out_shape, fill));
        }
        else if(input_as_shape == 0)
        {
            // The "shape" attribute supplies the output dimensions.
            if(!contains(attributes, "shape"))
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape out_shape{type, dims};
            std::vector<float> fill(out_shape.elements(), value);
            return prog.add_literal(migraphx::literal(out_shape, fill));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

Shucai Xiao's avatar
Shucai Xiao committed
992
993
994
    // ConstantOfShape: the input tensor supplies the output dimensions and the
    // one-element "value" attribute supplies the fill value; everything is
    // known at parse time, so the result becomes a literal.
    instruction_ref parse_constant_of_shape(const std::string&,
                                            attribute_map attributes,
                                            std::vector<instruction_ref> args)
    {
        literal fill_val{};
        if(contains(attributes, "value"))
        {
            fill_val = parse_value(attributes.at("value"));
            if(fill_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 elements!");
            }
        }
        else
        {
            // Default fill value is float 0.
            fill_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // input is empty, output is a scalar
        auto type = fill_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape : must have 1 input!");
        }
        else
        {
            migraphx::shape out_shape;
            // An empty input tensor means the output is a scalar.
            if(args[0]->get_shape().elements() == 0)
            {
                out_shape = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                out_shape = migraphx::shape{type, dims};
            }

            literal l_out{};
            fill_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // fill_val holds exactly one element; replicate it.
                std::vector<val_type> out_vec(out_shape.elements(), val.front());
                l_out = literal(out_shape, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }

Shucai Xiao's avatar
Shucai Xiao committed
1047
    // Expand: broadcast the input to the (parse-time constant) shape given by
    // the second argument.
    instruction_ref
    parse_expand(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto in_lens              = args[0]->get_shape().lens();
        migraphx::argument target = args[1]->eval();
        check_arg_empty(target, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        target.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }

Shucai Xiao's avatar
Shucai Xiao committed
1059
    std::vector<instruction_ref>
Shucai Xiao's avatar
Shucai Xiao committed
1060
1061
1062
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
1063
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];
Shucai Xiao's avatar
Shucai Xiao committed
1064
1065
1066

        if(contains(attributes, "hidden_size"))
        {
Shucai Xiao's avatar
Shucai Xiao committed
1067
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1068
            if(hidden_size != hidden_size_att)
Shucai Xiao's avatar
Shucai Xiao committed
1069
1070
1071
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
Shucai Xiao's avatar
Shucai Xiao committed
1072
1073
1074
1075
1076
1077
1078
1079
1080
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

1081
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1082
1083
        if(direction == "bidirectional")
        {
1084
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1085
1086
1087
        }
        else if(direction == "reverse")
        {
1088
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1089
1090
        }

1091
        std::vector<std::string> vec_names{"tanh"};
1092
1093
1094
1095
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
1096
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1097
1098
1099
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
1100
1101
        }

1102
1103
1104
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1105
        if(name_it != vec_names.end())
1106
1107
1108
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }
1109

Shucai Xiao's avatar
Shucai Xiao committed
1110
        // bidirectional case should have two activation functions.
Shucai Xiao's avatar
Shucai Xiao committed
1111
        // one is for forward, and the other is for reverse.
Shucai Xiao's avatar
Shucai Xiao committed
1112
        // if only one actv function is provided, we use it in both
1113
        // forward and reverse direction
1114
        if(dirct == op::rnn_direction::bidirectional)
1115
        {
Shucai Xiao's avatar
Shucai Xiao committed
1116
            if(vec_names.size() == 1)
1117
1118
1119
1120
1121
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

Shucai Xiao's avatar
Shucai Xiao committed
1122
        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1123
1124
1125
1126
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });
Shucai Xiao's avatar
Shucai Xiao committed
1127

Shucai Xiao's avatar
Shucai Xiao committed
1128
1129
1130
1131
1132
1133
1134
        // To be added later
        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

1135
1136
        // if the number of arguments is less than 6, append
        // undefined operator to have 6 arguments
Shucai Xiao's avatar
Shucai Xiao committed
1137
        if(args.size() < 6)
1138
1139
1140
1141
1142
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

Shucai Xiao's avatar
Shucai Xiao committed
1143
1144
        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
Shucai Xiao's avatar
Shucai Xiao committed
1145
                                                  std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1146

1147
        // second output for the last hidden state
Shucai Xiao's avatar
Shucai Xiao committed
1148
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
Shucai Xiao's avatar
Shucai Xiao committed
1149

Shucai Xiao's avatar
Shucai Xiao committed
1150
        return {hidden_states, last_output};
Shucai Xiao's avatar
Shucai Xiao committed
1151
1152
    }

1153
    std::vector<instruction_ref>
1154
1155
1156
1157
1158
1159
1160
    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
Shucai Xiao's avatar
Shucai Xiao committed
1161
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
1162
            if(hidden_size != hidden_size_att)
Shucai Xiao's avatar
Shucai Xiao committed
1163
1164
1165
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
1166
1167
1168
1169
1170
1171
1172
1173
1174
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

1175
        op::rnn_direction dirct = op::rnn_direction::forward;
1176
1177
        if(direction == "bidirectional")
        {
1178
            dirct = op::rnn_direction::bidirectional;
1179
1180
1181
        }
        else if(direction == "reverse")
        {
1182
            dirct = op::rnn_direction::reverse;
1183
1184
        }

1185
        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
1186
1187
        if(contains(attributes, "activations"))
        {
1188
            auto names = attributes.at("activations").strings();
1189
            vec_names.clear();
Shucai Xiao's avatar
Shucai Xiao committed
1190
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1191
1192
1193
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
1194
1195
        }

1196
        // need 4 activation functions
1197
        if(dirct == op::rnn_direction::bidirectional)
1198
        {
Shucai Xiao's avatar
Shucai Xiao committed
1199
            // 4 activation functions are used in the bidirectional
1200
            // scenario. No spec is provided in onnx::operator. we
Shucai Xiao's avatar
Shucai Xiao committed
1201
1202
            // use the algorithm that: if 1 actv function is provided,
            // repeat 1 four times. If 2 actv functins are provided,
1203
1204
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
Shucai Xiao's avatar
Shucai Xiao committed
1205
1206
1207
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need change later
1208
            if(vec_names.size() == 1)
1209
            {
1210
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
1211
            }
1212
            else if(vec_names.size() == 2)
1213
            {
1214
1215
1216
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
1217
            }
1218
            else if(vec_names.size() == 3)
1219
            {
1220
                vec_names.push_back(vec_names.at(2));
1221
1222
            }
        }
Shucai Xiao's avatar
Shucai Xiao committed
1223
        else
1224
        {
1225
            if(vec_names.size() == 1)
1226
            {
1227
                vec_names.push_back(vec_names.at(0));
1228
1229
1230
            }
        }

1231
1232
1233
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1234
        if(name_it != vec_names.end())
1235
1236
1237
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }
1238

Shucai Xiao's avatar
Shucai Xiao committed
1239
        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1240
1241
1242
1243
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });
1244
1245
1246
1247
1248
1249
1250
1251

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
Shucai Xiao's avatar
Shucai Xiao committed
1252
        if(contains(attributes, "linear_before_reset"))
1253
1254
1255
1256
        {
            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
        }

Shucai Xiao's avatar
Shucai Xiao committed
1257
        // append undefined opeator to make 6 arguments
Shucai Xiao's avatar
Shucai Xiao committed
1258
        if(args.size() < 6)
Shucai Xiao's avatar
Shucai Xiao committed
1259
1260
1261
1262
1263
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

1264
1265
        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
1266
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
Shucai Xiao's avatar
Shucai Xiao committed
1267
            std::move(args));
1268
1269

        // second output for last gru output
1270
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
1271

Shucai Xiao's avatar
Shucai Xiao committed
1272
        return {hidden_states, last_output};
1273
1274
    }

Shucai Xiao's avatar
Shucai Xiao committed
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
    std::vector<instruction_ref>
    parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

Shucai Xiao's avatar
Shucai Xiao committed
1297
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1298
1299
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1300
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1301
1302
1303
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1304
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1305
        }
Shucai Xiao's avatar
Shucai Xiao committed
1306
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
1307
        {
Shucai Xiao's avatar
Shucai Xiao committed
1308
            dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1309
1310
1311
1312
1313
1314
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

1315
        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
Shucai Xiao's avatar
Shucai Xiao committed
1316
1317
1318
1319
1320
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1321
1322
1323
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
Shucai Xiao's avatar
Shucai Xiao committed
1324
1325
1326
        }

        // need 6 activation functions for bidirectional directions
Shucai Xiao's avatar
Shucai Xiao committed
1327
        if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1328
1329
1330
1331
1332
1333
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator. we
            // use the algorithm that: if 1 actv function is provided,
            // repeat 1st six times. If 2 actv functins are provided,
            // repeat 2nd once, then repeat all three once
Shucai Xiao's avatar
Shucai Xiao committed
1334
            // if 3 actv funcs are provide, repeat all three once.
Shucai Xiao's avatar
Shucai Xiao committed
1335
1336
1337
1338
            // the same algorithm is used for 4, 5, and 6 actv funcions
            // provided. This may need change later
            switch(vec_names.size())
            {
1339
            case 1:
Shucai Xiao's avatar
Shucai Xiao committed
1340
1341
1342
1343
1344
1345
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
1346
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1347
1348
1349

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
Shucai Xiao's avatar
Shucai Xiao committed
1350
1351
1352
1353
1354
1355
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
Shucai Xiao's avatar
Shucai Xiao committed
1356
1357
1358
1359
                break;

            case 3:
                // repeat all three actv funcs once
Shucai Xiao's avatar
Shucai Xiao committed
1360
1361
1362
1363
1364
1365
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
Shucai Xiao's avatar
Shucai Xiao committed
1366
1367
                break;

Shucai Xiao's avatar
Shucai Xiao committed
1368
1369
1370
1371
1372
1373
1374
            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
1375
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1376

Shucai Xiao's avatar
Shucai Xiao committed
1377
1378
1379
1380
1381
1382
1383
            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
1384
                break;
Shucai Xiao's avatar
Shucai Xiao committed
1385

Shucai Xiao's avatar
Shucai Xiao committed
1386
            default: break;
Shucai Xiao's avatar
Shucai Xiao committed
1387
1388
1389
1390
1391
1392
            }
        }
        else
        {
            switch(vec_names.size())
            {
Shucai Xiao's avatar
Shucai Xiao committed
1393
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;
Shucai Xiao's avatar
Shucai Xiao committed
1394
1395
1396

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
Shucai Xiao's avatar
Shucai Xiao committed
1397
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
Shucai Xiao's avatar
Shucai Xiao committed
1398
1399
                break;

Shucai Xiao's avatar
Shucai Xiao committed
1400
            default: break;
Shucai Xiao's avatar
Shucai Xiao committed
1401
1402
1403
            }
        }

1404
1405
1406
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
1407
        if(name_it != vec_names.end())
1408
1409
1410
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }
Shucai Xiao's avatar
Shucai Xiao committed
1411
1412

        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
1413
1414
1415
1416
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });
Shucai Xiao's avatar
Shucai Xiao committed
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(attributes, "input_forget"))
        {
            input_forget = parse_value(attributes.at("input_forget")).at<int>();
        }

        // append undefined opeator to make 6 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
1434
            args.insert(args.end(), 8 - args.size(), ins);
Shucai Xiao's avatar
Shucai Xiao committed
1435
1436
1437
1438
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
Shucai Xiao's avatar
Shucai Xiao committed
1439
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1440
1441

        // second output for last lstm output
Shucai Xiao's avatar
Shucai Xiao committed
1442
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
Shucai Xiao's avatar
Shucai Xiao committed
1443
1444
1445
1446
1447
1448

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
1449

Shucai Xiao's avatar
Shucai Xiao committed
1450
    template <class T>
Shucai Xiao's avatar
Shucai Xiao committed
1451
    instruction_ref parse_reduce_oper(const std::string&,
Shucai Xiao's avatar
Shucai Xiao committed
1452
1453
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
1454
1455
1456
1457
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
1458
        std::vector<int64_t> axes(n_dim);
Shucai Xiao's avatar
Shucai Xiao committed
1459
1460
1461
1462
1463
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = attributes["axes"].ints();
1464
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
Shucai Xiao's avatar
Shucai Xiao committed
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
Shucai Xiao's avatar
Shucai Xiao committed
1475
            return prog.add_instruction(T{axes}, std::move(args));
Shucai Xiao's avatar
Shucai Xiao committed
1476
1477
1478
        }
        else
        {
Shucai Xiao's avatar
Shucai Xiao committed
1479
            auto ins = prog.add_instruction(T{axes}, std::move(args));
1480
            return prog.add_instruction(op::squeeze{axes}, ins);
1481
1482
        }
    }
1483

Shucai Xiao's avatar
Shucai Xiao committed
1484
1485
    /// Parse an ONNX Cast node; the mandatory "to" attribute selects the
    /// destination element type.
    instruction_ref
    parse_cast(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        if(not contains(attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        const int to_type = parse_value(attributes.at("to")).at<int>();
        return prog.add_instruction(op::convert{get_type(to_type)}, std::move(args));
    }
Shucai Xiao's avatar
Shucai Xiao committed
1496

Paul's avatar
Paul committed
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
Paul's avatar
Paul committed
1509
            MIGRAPHX_THROW("Failed reading onnx file.");
Paul's avatar
Paul committed
1510
1511
1512
1513
1514
1515
        }
    }

    /// Translate an onnx::GraphProto into program instructions:
    /// initializers become literals, non-initializer inputs become parameters,
    /// and each graph output is parsed recursively via parse_node.
    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type(), batch_size);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }

        // For now, the last output with a valid name is considered
        // as the program output, and add an identity instruction at
        // the program end
        auto prog_output = graph.output();
        auto oit         = std::find_if(prog_output.rbegin(), prog_output.rend(), [](auto& node) {
            return !node.name().empty();
        });

        // Guard against a graph whose outputs all have empty names:
        // dereferencing rend() would be undefined behavior.
        if(oit != prog_output.rend() and instructions.count(oit->name()) > 0)
        {
            prog.add_instruction(op::identity{}, instructions[oit->name()]);
        }
    }

Shucai Xiao's avatar
Shucai Xiao committed
1549
    void parse_undefined(const std::string& name)
1550
    {
Shucai Xiao's avatar
Shucai Xiao committed
1551
        auto ins           = prog.add_instruction(op::undefined{});
1552
1553
1554
        instructions[name] = ins;
    }

Paul's avatar
Paul committed
1555
    void parse_node(const std::string& name)
Paul's avatar
Paul committed
1556
    {
Paul's avatar
Paul committed
1557
        if(name.empty())
Paul's avatar
Paul committed
1558
            MIGRAPHX_THROW("Onnx node must have a name");
Paul's avatar
Paul committed
1559
1560
1561
1562
1563
1564
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
Shucai Xiao's avatar
Shucai Xiao committed
1565
                if(input.empty())
Paul's avatar
Paul committed
1566
                {
Shucai Xiao's avatar
Shucai Xiao committed
1567
                    this->parse_undefined(input);
Paul's avatar
Paul committed
1568
                }
Shucai Xiao's avatar
Shucai Xiao committed
1569
                else if(nodes.count(input) > 0)
Paul's avatar
Paul committed
1570
                {
Shucai Xiao's avatar
Shucai Xiao committed
1571
1572
                    assert(name != input);
                    this->parse_node(input);
Paul's avatar
Paul committed
1573
                }
1574
                args.push_back(instructions.at(input));
Paul's avatar
Paul committed
1575
            }
Paul's avatar
Paul committed
1576
            std::vector<instruction_ref> result;
Paul's avatar
Paul committed
1577
1578
            if(ops.count(node.op_type()) == 0)
            {
1579
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
Paul's avatar
Paul committed
1580
1581
1582
            }
            else
            {
Paul's avatar
Paul committed
1583
                result = ops[node.op_type()](get_attributes(node), args);
Paul's avatar
Paul committed
1584
            }
Paul's avatar
Paul committed
1585
            // Even no output nodes produce output in migraphx
Paul's avatar
Paul committed
1586
            if(node.output().empty() and result.size() == 1)
Paul's avatar
Paul committed
1587
1588
            {
                instructions[name] = result.front();
Paul's avatar
Paul committed
1589
1590
1591
            }
            else
            {
Shucai Xiao's avatar
Shucai Xiao committed
1592
1593
1594
1595
                assert(node.output().size() <= result.size());
                std::transform(node.output().begin(),
                               node.output().end(),
                               result.begin(),
Paul's avatar
Paul committed
1596
                               std::inserter(instructions, instructions.end()),
Shucai Xiao's avatar
Shucai Xiao committed
1597
                               [](auto&& x, auto&& y) { return std::make_pair(x, y); });
Paul's avatar
Paul committed
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
            }
        }
    }

    /// Build a name -> AttributeProto map for a node; a repeated attribute
    /// name keeps the last occurrence.
    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        attribute_map attrs;
        for(auto&& attr : node.attribute())
        {
            attrs[attr.name()] = attr;
        }
        return attrs;
    }

    /// Index every node of the graph: each output name maps to its producing
    /// node, and nodes with no outputs are keyed by their own (possibly
    /// generated) name so they can still be looked up.
    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        node_map result;
        std::size_t unnamed_count = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    // synthesize a stable key for anonymous, output-less nodes
                    result["migraphx_unamed_node_" + std::to_string(unnamed_count)] = node;
                    unnamed_count++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

Paul's avatar
Paul committed
1638
1639
1640
1641
1642
1643
    /// Extract an attribute's values as a vector of int64 indices.
    /// Oversized indices (beyond half of INT32_MAX) are clamped to -1.
    static std::vector<int64_t> get_indices(const onnx::AttributeProto& attr)
    {
        std::vector<int64_t> indices;
        literal value = parse_value(attr);
        value.visit([&](auto v) { copy(v, std::back_inserter(indices)); });
        // Clamp large indices to -1
        std::replace_if(
            indices.begin(),
            indices.end(),
            [](auto x) { return x > int64_t{std::numeric_limits<std::int32_t>::max()} / 2; },
            -1);
        return indices;
    }

Paul's avatar
Paul committed
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
    /// Wrap a protobuf repeated field into a 1-D literal of the given type.
    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        const std::size_t n = r.size();
        return literal{{t, {n}}, r.begin(), r.end()};
    }

    /// Convert an AttributeProto into a literal. Scalar and repeated numeric
    /// attributes and tensors are supported; every other attribute kind
    /// yields an empty literal.
    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        // scalar attributes
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        // embedded tensor
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        // repeated numeric attributes
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        // unsupported kinds parse to an empty literal
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    /// Convert a TensorProto into a literal. Handles both the raw_data
    /// byte-blob encoding and the typed repeated-field encoding; narrow
    /// integer types and bool are widened into int32 literals.
    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());

        if(t.has_raw_data())
        {
            // raw_data holds the elements as a packed little-endian byte blob
            const std::string& raw = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT:
                return create_literal(shape::float_type, dims, raw.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, raw.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, raw.data());
            case onnx::TensorProto::INT64:
                return create_literal(shape::int64_type, dims, raw.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL:
                return create_literal(shape::int32_type, dims, raw.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }

        // typed repeated-field encoding
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            // fp16 payloads arrive as int32 entries; reinterpret each low
            // 16-bit pattern as a half value
            std::vector<uint16_t> bits(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> values;
            std::transform(bits.begin(),
                           bits.end(),
                           std::back_inserter(values),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, values);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

Khalique's avatar
Khalique committed
1742
    /// Build a literal of the given type/dims from a raw byte pointer.
    /// Scalar constants in an onnx file carry no dims, so an empty dims
    /// vector produces a scalar-shaped literal.
    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        if(dims.empty())
        {
            return literal{{shape_type}, data};
        }
        return literal{{shape_type, dims}, data};
    }

1751
    /// Build a literal of the given type/dims from any iterable container.
    /// As with the raw-pointer overload, empty dims denote a scalar.
    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
        {
            return literal{{shape_type}, data.begin(), data.end()};
        }
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

1759
    /// Derive a migraphx shape from an onnx TypeProto. Dimensions without a
    /// concrete value (or with a non-positive one) are substituted with the
    /// user-supplied batch size.
    static shape parse_type(const onnx::TypeProto& t, const unsigned int batch_size)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        // unsupported element types fall through with a default-initialized type
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }

        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           // unknown or non-positive dims fall back to batch_size
                           if(not d.has_dim_value())
                               return batch_size;
                           if(static_cast<int>(d.dim_value()) <= 0)
                               return batch_size;
                           return d.dim_value();
                       });

        if(dims.empty())
        {
            return {shape_type};
        }
        return {shape_type, dims};
    }
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822

    /// Map an onnx TensorProto data-type code to the corresponding migraphx
    /// shape type; throws for codes with no migraphx equivalent.
    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;   // FLOAT
        case 2: return shape::uint8_type;   // UINT8
        case 3: return shape::int8_type;    // INT8
        case 4: return shape::uint16_type;  // UINT16
        case 5: return shape::int16_type;   // INT16
        case 6: return shape::int32_type;   // INT32
        case 7: return shape::int64_type;   // INT64
        case 10: return shape::half_type;   // FLOAT16
        case 11: return shape::double_type; // DOUBLE
        case 12: return shape::uint32_type; // UINT32
        case 13: return shape::uint64_type; // UINT64
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
Shucai Xiao's avatar
Shucai Xiao committed
1823
1824
1825

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
Shucai Xiao's avatar
Shucai Xiao committed
1826
        if(arg.empty())
Shucai Xiao's avatar
Shucai Xiao committed
1827
1828
1829
1830
        {
            MIGRAPHX_THROW(msg);
        }
    }
Paul's avatar
Paul committed
1831
1832
};

1833
/// Parse the onnx file at `name` into a program, applying the given options
/// (currently just the batch size used for unknown dimensions).
program parse_onnx(const std::string& name, onnx_options options)
{
    std::fstream file(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
    parser.batch_size = options.batch_size;
#ifndef NDEBUG
    // In debug builds, dump the partially-built program when parsing fails
    try
    {
        parser.parse_from(file);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(file);
#endif
    // prog is a member of a local object, so the move is required here
    return std::move(parser.prog);
}

Paul's avatar
Paul committed
1855
} // namespace MIGRAPHX_INLINE_NS
Paul's avatar
Paul committed
1856
} // namespace migraphx