#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <algorithm>
#include <array>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog    = program();
    bool is_pytorch = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax);
        add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names in all lower case or with the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }
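
    // Lookup sketch: activation names coming from ONNX attributes (e.g.
    // "Tanh" or "LeakyRelu") are normalized with to_lower() by the
    // RNN/GRU/LSTM parsers below and then resolved through this map, so
    // "Tanh" -> "tanh" -> op::tanh{}.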

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }
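
    // Sketch of the legacy (pre-opset-7) broadcast handled above, assuming
    // the usual ONNX semantics: Add with A of shape (2,3,4,5), B of shape
    // (3,4), broadcast=1 and axis=1 aligns B with A starting at axis 1, so B
    // is broadcast to (2,3,4,5) before the elementwise op is applied.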

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [](auto a, auto b) { return std::max(a, b); });

        return out_lens;
    }
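
    // Worked example: s0 = (2,1,1) and s1 = (3,2,4,5) align on trailing
    // dimensions and take the elementwise max, giving (3,2,4,5). Note this
    // assumes the shapes are broadcast-compatible (each aligned pair of
    // dimensions is equal or one of them is 1); incompatible pairs are not
    // diagnosed here.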

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);
            auto l0       = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
            auto l1       = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }
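
    // A variadic op folds left across its inputs, so Sum(a, b, c) lowers to
    // add(add(a, b), c), with each step going through
    // add_broadcastable_binary_op so the operands may have different shapes.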

    instruction_ref parse_clip(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        op::clip op;
        if(contains(attributes, "max"))
        {
            op.max_val = parse_value(attributes.at("max")).at<float>();
        }
        if(contains(attributes, "min"))
        {
            op.min_val = parse_value(attributes.at("min")).at<float>();
        }
        return prog.add_instruction(op, std::move(args));
    }
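
    // Clip clamps elementwise, y = min(max(x, min_val), max_val); when an
    // attribute is absent the corresponding default of op::clip is kept.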

    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto dims = args.front()->get_shape().lens();
        auto r =
            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
        auto s = prog.add_instruction(op::softmax{}, r);
        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
    }
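
    // Note on parse_softmax above: the input is reshaped to {d0, d1, 1, 1},
    // presumably because op::softmax operates on 4-D tensors here, then the
    // result is reshaped back to {d0, d1}. Since only dims[0] and dims[1]
    // are used, this path assumes a rank-2 input.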

    instruction_ref parse_logsoftmax(const std::string&,
                                     const attribute_map& attributes,
                                     std::vector<instruction_ref> args)
    {
        int axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
    }

    instruction_ref parse_argmax(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(op::argmax{axis}, std::move(args));
    }

    instruction_ref parse_argmin(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(op::argmin{axis}, std::move(args));
    }

    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
            if(contains(attributes, "auto_pad"))
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0      = prog.add_instruction(op::pad{padding}, l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        if(args.size() == 3)
        {
            // convolution with bias: broadcast the bias over the channel axis
            uint64_t axis = 1;
            auto l1       = prog.add_instruction(op, l0, args[1]);
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, l0, args[1]);
    }
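
    // Note on the "pads" handling above: ONNX stores the four 2-D conv pads
    // as {top, left, bottom, right}. Symmetric padding maps directly onto
    // op::convolution, while asymmetric padding is realized with an explicit
    // op::pad whose 8 values pad only the two spatial dimensions.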

    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            // global pooling uses the full spatial extent as the kernel
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }
        if(contains(attributes, "pads"))
        {
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
                                          l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(s.find("SAME_UPPER") == std::string::npos)
            {
                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
            }
            op.padding_mode = op::padding_mode_t::same;
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            if(s.empty())
                MIGRAPHX_THROW("Dynamic shape is not supported.");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        uint64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }
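
    // Flatten collapses every dimension before `axis` into the first output
    // dimension and the remainder into the second; e.g. a (2,3,4,5) input
    // with axis=2 flattens to (6,20).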

    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::size_t axis = parse_value(attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        op::gather op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        if(contains(attributes, "axes"))
        {
            literal s = parse_value(attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        return prog.add_instruction(op, args[0]);
    }
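
    // Note on parse_slice above: "starts" and "ends" are required (the
    // attributes.at calls throw when they are missing) while "axes" is
    // optional; an empty axes list defers to the default behavior of
    // op::slice.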

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v     = parse_value(attributes.at("value"));
        auto dim_size = attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        if(contains(attributes, "beta"))
        {
            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
            transa = parse_value(attributes.at("transA")).at<bool>();
        }
        if(contains(attributes, "transB"))
        {
            transb = parse_value(attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }
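
    // Gemm computes Y = alpha * op(A) * op(B) + beta * C, where op() is an
    // optional transpose of the last two axes; C is only used when present
    // with beta != 0, and is multibroadcast to the output shape before being
    // handed to op::dot.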

    instruction_ref
    parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // args[1] is a vector, append 1 to the shape
        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }
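
    // MatMul follows numpy matmul semantics: a 1-D A gets a 1 prepended and
    // a 1-D B gets a 1 appended (squeezed away again afterwards), and batch
    // dimensions beyond the last two are multibroadcast to a common shape.
    // Example: A (2,1,3,4) x B (5,4,2) broadcasts to (2,5,3,4) x (2,5,4,2)
    // and yields (2,5,3,2).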

    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        bool is_test                                      = false;
        if(contains(attributes, "epsilon"))
        {
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        }
        if(contains(attributes, "momentum"))
        {
            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "is_test"))
        {
            is_test = parse_value(attributes.at("is_test")).at<uint64_t>() > 0;
        }
        if(contains(attributes, "spatial"))
        {
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        (void)is_test;
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }
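
    // Inference-time batch normalization computes
    // y = scale * (x - mean) / sqrt(var + epsilon) + bias, with scale, bias,
    // mean, and var supplied as args[1..4]; the training-related is_test
    // attribute is parsed but deliberately unused.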

    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }
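
    // For reference: LeakyRelu computes f(x) = x for x >= 0 and alpha * x
    // otherwise, while Elu computes f(x) = x for x >= 0 and
    // alpha * (exp(x) - 1) otherwise.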

    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        if(contains(attributes, "beta"))
            beta = parse_value(attributes.at("beta")).at<float>();
        if(contains(attributes, "bias"))
            bias = parse_value(attributes.at("bias")).at<float>();
        if(contains(attributes, "size"))
            size = parse_value(attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }
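
    // LRN normalizes across channels over a window of `size` channels,
    // y = x / (bias + (alpha / size) * sum(x_i^2))^beta, matching the ONNX
    // definition of the attributes parsed above.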

    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(attributes, "scale"))
        {
            scale = parse_value(attributes.at("scale")).at<float>();
        }

        if(contains(attributes, "bias"))
        {
            auto&& bias_floats = attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_lens = args.front()->get_shape().lens();

        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }
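
    // ImageScaler computes y = scale * x + bias[c] per channel c of an NCHW
    // input: the scalar scale is splatted to the input shape and the
    // per-channel bias is broadcast along axis 1.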

    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
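
    // ONNX "pads" lists all begin values followed by all end values, e.g.
    // {x1_begin, x2_begin, ..., x1_end, x2_end, ...} for an n-D input, which
    // appears to match the layout op::pad expects (see the asymmetric
    // padding handling in parse_conv above).
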
    // Use a literal instruction to replace the shape, since the output of
    // the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
        {
            dtype = parse_value(attributes.at("dtype")).at<int>();
        }
        migraphx::shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
        }

        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }

        if(contains(attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            if(in.empty())
            {
                MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
            }

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

    std::vector<instruction_ref>
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case needs two activation functions: one for
        // forward and one for reverse. If only one activation function is
        // provided, use it for both directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
            return map_actv_funcs[fn];
        });

        // To be added later
        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to reach 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output: the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
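
    // The ONNX RNN operator takes up to six inputs (X, W, R, B,
    // sequence_lens, initial_h), which is why missing trailing arguments are
    // padded with op::undefined above; the two outputs are Y (all hidden
    // states) and Y_h (the last hidden state).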

    std::vector<instruction_ref>
    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided by the onnx operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 are provided, assume
            // forward and reverse use the same pair. If 3 are provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output: the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output: the last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }
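
    // The default GRU activations {sigmoid, tanh} correspond to the gate
    // function f and the candidate-hidden function g in the ONNX GRU
    // definition; a nonzero linear_before_reset applies the recurrent linear
    // transform to the previous hidden state before multiplying by the reset
    // gate.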

Shucai Xiao's avatar
Shucai Xiao committed
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
    std::vector<instruction_ref>
    parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

Shucai Xiao's avatar
Shucai Xiao committed
1102
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1103
1104
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1105
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1106
1107
1108
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1109
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1110
        }
Shucai Xiao's avatar
Shucai Xiao committed
1111
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
1112
        {
Shucai Xiao's avatar
Shucai Xiao committed
1113
            dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
1116
1117
1118
1119
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

1120
        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
Shucai Xiao's avatar
Shucai Xiao committed
1121
1122
1123
1124
1125
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
1126
1127
1128
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
Shucai Xiao's avatar
Shucai Xiao committed
1129
1130
1131
        }

        // need 6 activation functions for bidirectional directions
Shucai Xiao's avatar
Shucai Xiao committed
1132
        if(dirct == op::rnn_direction::bidirectional)
Shucai Xiao's avatar
Shucai Xiao committed
1133
1134
1135
1136
1137
1138
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in the ONNX operator
            // definition, so we use the following algorithm: if 1
            // activation function is provided, repeat it six times;
            // if 2 are provided, repeat the 2nd once, then repeat
            // all three once; if 3 are provided, repeat all three once.
            // the same padding scheme is used when 4, 5, or 6 activation
            // functions are provided. This may need to change later
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three once more
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }
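        // Worked example of the padding above: a bidirectional LSTM given
        // activations {"sigmoid", "tanh"} ends up with
        // {sigmoid, tanh, tanh, sigmoid, tanh, tanh}; the first three apply to
        // the forward direction, the last three to the reverse direction.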

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }
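        // A clip of 0 is taken to mean "no clipping", matching the ONNX default
        // when the attribute is absent (assuming op::lstm interprets 0 that way).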

        int input_forget = 0;
        if(contains(attributes, "input_forget"))
        {
            input_forget = parse_value(attributes.at("input_forget")).at<int>();
        }
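        // Per the ONNX spec, a nonzero input_forget couples the input and
        // forget gates.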

        // append undefined operators so that there are 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }
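        // The 8 ONNX LSTM inputs, in order, are X, W, R, B, sequence_lens,
        // initial_h, initial_c, and the peephole weights P; omitted trailing
        // inputs are represented by the op::undefined placeholders above.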

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        std::unordered_map<std::string, onnx::TensorProto> initializer_data;
        for(auto&& f : graph.initializer())
        {
            initializer_data[f.name()] = f;
        }
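        // Initializers carry the trained weights; they are baked into the
        // program as literals rather than exposed as runtime parameters.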
        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // Does the input have an initializer?
            if(contains(initializer_data, name))
            {
                auto t             = initializer_data[name];
                instructions[name] = prog.add_literal(parse_tensor(t));
            }
            else
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type());
                instructions[name] = prog.add_parameter(name, s);
            }
        }
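        // Build the program by walking backwards from each graph output;
        // parse_node recursively materializes every node it depends on.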
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }
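    // An empty input name in ONNX denotes an omitted optional input;
    // parse_node routes such inputs through parse_undefined above.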

    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                else if(input.empty())
                {
                    this->parse_undefined(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
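            // Operators without a registered parser are kept in the graph as
            // op::unknown so that the rest of the model can still be parsed.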
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()](get_attributes(node), args);
            }
            // Even nodes with no declared outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                assert(node.output().size() >= result.size());
                std::transform(result.begin(),
                               result.end(),
                               node.output().begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
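                // Multi-output nodes (e.g. LSTM) register one instruction per
                // declared output name, pairing result[i] with output i.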
            }
        }
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
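            // Index each node under every one of its output names so that
            // parse_node can locate the producer of any input.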
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::UNDEFINED: return {};
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::STRING: return {};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::GRAPH: return {};
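        // String and graph attributes have no literal form; callers read them
        // from the proto directly (e.g. attributes.at("direction").s() in the
        // RNN parsers).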
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::STRINGS: return {};
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
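        // Tensor payloads arrive either packed in raw_data or in the typed
        // repeated fields; both layouts are handled below.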
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::UNDEFINED:
                throw std::runtime_error("Unsupported type UNDEFINED");
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::UINT8: throw std::runtime_error("Unsupported type UINT8");
            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT16:
                return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::STRING: throw std::runtime_error("Unsupported type STRING");
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::UINT32: throw std::runtime_error("Unsupported type UINT32");
            case onnx::TensorProto::UINT64: throw std::runtime_error("Unsupported type UINT64");
            case onnx::TensorProto::COMPLEX64:
                throw std::runtime_error("Unsupported type COMPLEX64");
            case onnx::TensorProto::COMPLEX128:
                throw std::runtime_error("Unsupported type COMPLEX128");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::UNDEFINED: throw std::runtime_error("Unsupported type UNDEFINED");
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::UINT8: throw std::runtime_error("Unsupported type UINT8");
        case onnx::TensorProto::INT8:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::UINT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT32:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::STRING: throw std::runtime_error("Unsupported type STRING");
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::UINT32: throw std::runtime_error("Unsupported type UINT32");
        case onnx::TensorProto::UINT64: throw std::runtime_error("Unsupported type UINT64");
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("Unsupported type COMPLEX64");
        case onnx::TensorProto::COMPLEX128:
            throw std::runtime_error("Unsupported type COMPLEX128");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // scalar constants in the onnx file have empty dims; build the literal from a scalar shape
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::UNDEFINED:
            break; // throw std::runtime_error("Unsupported type UNDEFINED");
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::UINT8:
            break; // throw std::runtime_error("Unsupported type UINT8");
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::STRING:
            break; // throw std::runtime_error("Unsupported type STRING");
        case onnx::TensorProto::BOOL:
            break; // throw std::runtime_error("Unsupported type BOOL");
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::COMPLEX64:
            break; // throw std::runtime_error("Unsupported type COMPLEX64");
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type COMPLEX128");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [](auto&& d) -> std::size_t {
                           if(not d.has_dim_value())
                           {
                               long default_batch_size = 1; // FIXME
                               return default_batch_size;
                           }
                           return d.dim_value();
                       });
        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
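    // The integer codes above follow onnx::TensorProto::DataType; codes 8
    // (STRING) and 9 (BOOL) are deliberately unsupported here.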
};

program parse_onnx(const std::string& name)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    return std::move(parser.prog);
}
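// Minimal usage sketch ("model.onnx" is a hypothetical file name):
//
//     migraphx::program p = migraphx::parse_onnx("model.onnx");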

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx