#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <algorithm>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

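// Parses an ONNX graph into a migraphx::program. Each ONNX operator name maps
// to a handler in `ops` that appends the equivalent migraphx instruction(s).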
struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog    = program();
    bool is_pytorch = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax);
        add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);

        // init the activation function map
        init_actv_func();
    }

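    // Activation operations that RNN/GRU/LSTM activation-name attributes can
    // refer to; names are lowercased before lookup.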
    void init_actv_func()
    {
        // Support name format of all lower case or the first letter capital
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

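    // Register a handler that produces a single instruction; the result is
    // wrapped in a one-element vector to match op_func's multi-output signature.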
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

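    // Compute the multidirectional-broadcast output dims of two shapes:
    // right-align the two, then take the elementwise max of each dim pair.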
    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [](auto a, auto b) { return std::max(a, b); });

        return out_lens;
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);
            auto l0       = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
            auto l1       = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

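    // Register a variadic op (Sum/Max/Min) as a left fold of a binary op over
    // all inputs, broadcasting operands as needed.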
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    instruction_ref parse_clip(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        op::clip op;
        if(contains(attributes, "max"))
        {
            op.max_val = parse_value(attributes.at("max")).at<float>();
        }
        if(contains(attributes, "min"))
        {
            op.min_val = parse_value(attributes.at("min")).at<float>();
        }
        return prog.add_instruction(op, std::move(args));
    }

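    // Softmax: assumes a 2D {batch, features} input; reshape to {n, c, 1, 1},
    // apply op::softmax, then restore the original 2D shape.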
    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto dims = args.front()->get_shape().lens();
        auto r =
            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
        auto s = prog.add_instruction(op::softmax{}, r);
        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
    }

    instruction_ref parse_logsoftmax(const std::string&,
                                     const attribute_map& attributes,
                                     std::vector<instruction_ref> args)
    {
        int axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
    }

    instruction_ref parse_argmax(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        return prog.add_instruction(op::argmax{axis, keep_dims}, std::move(args));
    }

    instruction_ref parse_argmin(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        return prog.add_instruction(op::argmin{axis, keep_dims}, std::move(args));
    }

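    // Conv: maps pads/strides/dilations/group onto op::convolution. Asymmetric
    // padding is lowered to an explicit op::pad, and a third input (bias) is
    // broadcast along the channel axis and added to the convolution output.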
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
            if(contains(attributes, "auto_pad"))
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0      = prog.add_instruction(op::pad{padding}, l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        if(args.size() == 3)
        {
            uint64_t axis = 1;
            // convolve l0 (not args[0]) so explicit padding inserted above is applied
            auto l1       = prog.add_instruction(op, l0, args[1]);
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, l0, args[1]);
    }

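    // Pooling: MaxPool/AveragePool and their Global* variants share this handler.
    // Global pooling uses the full spatial dims as the kernel; asymmetric pads
    // are lowered to an explicit op::pad filled with the lowest float value.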
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }
        if(contains(attributes, "pads"))
        {
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
                                          l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(s.find("SAME_UPPER") == std::string::npos)
            {
                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
            }
            op.padding_mode = op::padding_mode_t::same;
        }

        return prog.add_instruction(op, l0);
    }

    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            if(s.empty())
                MIGRAPHX_THROW("Dynamic shape is not supported.");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        uint64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::size_t axis = parse_value(attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }
    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        op::gather op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        if(contains(attributes, "axes"))
        {
            literal s = parse_value(attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v     = parse_value(attributes.at("value"));
        auto dim_size = attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

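    // Gemm: Y = alpha * op(A) * op(B) + beta * C. Transposes are implemented by
    // permuting the last two axes; C is multibroadcast to the output shape when
    // its dimensions differ.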
    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        if(contains(attributes, "beta"))
        {
            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
            transa = parse_value(attributes.at("transA")).at<bool>();
        }
        if(contains(attributes, "transB"))
        {
            transb = parse_value(attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

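    // MatMul follows NumPy semantics: 1-D operands are temporarily unsqueezed
    // to matrices, batch dimensions are multibroadcast to a common shape, and
    // any dimension added for a 1-D operand is squeezed away from the result.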
    instruction_ref
    parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        bool is_test                                      = false;
        if(contains(attributes, "epsilon"))
        {
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        }
        if(contains(attributes, "momentum"))
        {
            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "is_test"))
        {
            is_test = parse_value(attributes.at("is_test")).at<uint64_t>() > 0;
        }
        if(contains(attributes, "spatial"))
        {
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        (void)is_test;
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        if(contains(attributes, "beta"))
            beta = parse_value(attributes.at("beta")).at<float>();
        if(contains(attributes, "bias"))
            bias = parse_value(attributes.at("bias")).at<float>();
        if(contains(attributes, "size"))
            size = parse_value(attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

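    // ImageScaler: out = input * scale + bias, with the per-channel bias
    // broadcast along axis 1 (channels).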
    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(attributes, "scale"))
        {
            scale = parse_value(attributes.at("scale")).at<float>();
        }

        if(contains(attributes, "bias"))
        {
            auto&& bias_floats = attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_lens = args.front()->get_shape().lens();

        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the shape, since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
        {
            dtype = parse_value(attributes.at("dtype")).at<int>();
        }
        migraphx::shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
        }

        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }

        if(contains(attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            if(in.empty())
            {
                MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
            }

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

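    // RNN: returns two results, the concatenated hidden states for every time
    // step and the last hidden state (via op::rnn_last_output). Missing optional
    // inputs are filled with op::undefined so the operator always sees 6 args.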
    std::vector<instruction_ref>
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions, one for
        // the forward direction and one for the reverse direction. If only one
        // actv function is provided, use it for both directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
            return map_actv_funcs[fn];
        });

        // To be added later
        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to reach 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

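    // GRU: same two-output structure as RNN. A bidirectional GRU needs four
    // activation functions (two per direction); shorter activation lists are
    // extended by repetition as commented below.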
    std::vector<instruction_ref>
    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator. We
            // use the algorithm that: if 1 actv function is provided,
            // repeat it four times. If 2 actv functions are provided,
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need change later
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

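    // LSTM: a bidirectional LSTM needs six activation functions (three per
    // direction); shorter activation lists are extended by repetition in the
    // switch below.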
Shucai Xiao's avatar
Shucai Xiao committed
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    std::vector<instruction_ref>
    parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

Shucai Xiao's avatar
Shucai Xiao committed
1113
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1114
1115
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1116
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1117
1118
1119
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1120
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1121
        }
Shucai Xiao's avatar
Shucai Xiao committed
1122
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
1123
        {
Shucai Xiao's avatar
Shucai Xiao committed
1124
            dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1125
1126
1127
1128
1129
1130
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // 6 activation functions are needed for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario, but no expansion rule is given in the onnx spec.
            // The algorithm used here is: if 1 actv function is provided,
            // repeat it six times. If 2 actv functions are provided,
            // repeat the 2nd once, then repeat all three once.
            // If 3 actv funcs are provided, repeat all three once.
            // The same algorithm is used when 4, 5, or 6 actv functions
            // are provided. This may need to change later.
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }
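        // Net effect of the expansion above, for illustration: {"sigmoid"}
        // becomes six copies of "sigmoid", and {"sigmoid", "tanh"} becomes
        // {"sigmoid", "tanh", "tanh", "sigmoid", "tanh", "tanh"}.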

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(attributes, "input_forget"))
        {
            input_forget = parse_value(attributes.at("input_forget")).at<int>();
        }

        // append undefined operators so the op always receives 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);
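        // hidden_states, last_output, and last_cell_output correspond to the
        // ONNX LSTM outputs Y (all hidden states), Y_h (last hidden state),
        // and Y_c (last cell state), respectively.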

        return {hidden_states, last_output, last_cell_output};
    }

    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }
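    // Minimal usage sketch (the file name is hypothetical; this mirrors
    // parse_onnx at the bottom of this file):
    //   std::ifstream is{"model.onnx", std::ios::binary};
    //   onnx_parser parser;
    //   parser.parse_from(is);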

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
        std::unordered_map<std::string, onnx::TensorProto> initializer_data;
        for(auto&& f : graph.initializer())
        {
            initializer_data[f.name()] = f;
        }
        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // Does the input have an initializer?
            if(contains(initializer_data, name))
            {
                auto t             = initializer_data[name];
                instructions[name] = prog.add_literal(parse_tensor(t));
            }
            else
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type());
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }
    }
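    // Note on the flow above: initializers become literals, the remaining
    // graph inputs become parameters, and nodes are built on demand by
    // walking backwards from each graph output.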

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }
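    // op::undefined serves as a placeholder instruction: parse_node uses it
    // for inputs with empty names, and the GRU/LSTM parsers above use it to
    // pad optional trailing arguments.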

    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                else if(input.empty())
                {
                    this->parse_undefined(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()](get_attributes(node), args);
            }
            // Even nodes with no outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
            {
                instructions[name] = result.front();
            }
            else
            {
                assert(node.output().size() >= result.size());
                std::transform(result.begin(),
                               result.end(),
                               node.output().begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
            }
        }
    }
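    // parse_node is a demand-driven recursive walk: each node first parses
    // the nodes producing its inputs, so instructions enter the program in
    // topological order; the assert above only guards against trivial
    // self-references.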

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }
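    // The returned map is keyed by output name (one entry per output); nodes
    // without outputs are keyed by their node name, or by a generated
    // "migraphx_unnamed_node_<n>" name, so parse_node can find the producer
    // of any input string.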

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }
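    // For example, from_repeated(shape::int64_type, attr.ints()) produces a
    // one-dimensional int64 literal holding the attribute's values.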

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::UNDEFINED: return {};
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::STRING: return {};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::GRAPH: return {};
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::STRINGS: return {};
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::UINT8: throw std::runtime_error("");
            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT16:
                return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::STRING: throw std::runtime_error("");
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::UINT32: throw std::runtime_error("");
            case onnx::TensorProto::UINT64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::UINT8: throw std::runtime_error("");
        case onnx::TensorProto::INT8:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::UINT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT32:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::STRING: throw std::runtime_error("");
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::UINT32: throw std::runtime_error("");
        case onnx::TensorProto::UINT64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }
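    // Note: the onnx protobuf stores float16 payloads in int32_data, so the
    // non-raw FLOAT16 case above reinterprets each low 16-bit pattern as a
    // half value.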

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // scalar constants in the onnx file come with empty dims; fall back
        // to a single-element literal
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::UNDEFINED:
            break; // throw std::runtime_error("Unsupported type UNDEFINED");
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::UINT8:
            break; // throw std::runtime_error("Unsupported type UINT8");
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::STRING:
            break; // throw std::runtime_error("Unsupported type STRING");
        case onnx::TensorProto::BOOL:
            break; // throw std::runtime_error("Unsupported type BOOL");
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::COMPLEX64:
            break; // throw std::runtime_error("Unsupported type COMPLEX64");
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type COMPLEX128");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [](auto&& d) -> std::size_t {
                           if(not d.has_dim_value())
                           {
                               long default_batch_size = 1; // FIXME
                               return default_batch_size;
                           }
                           return d.dim_value();
                       });
        return {shape_type, dims};
    }

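    // get_type maps raw onnx::TensorProto::DataType codes (e.g. 1 = FLOAT,
    // 10 = FLOAT16) to migraphx shape types; codes with no migraphx
    // equivalent, such as STRING (8) and BOOL (9), hit the default case and
    // throw.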
    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
};

program parse_onnx(const std::string& name)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    return std::move(parser.prog);
}
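// Minimal usage sketch (the model path is hypothetical):
//   auto prog = migraphx::parse_onnx("model.onnx");
//   std::cout << prog << std::endl;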

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx