#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <algorithm>
#include <cstdint>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

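// Parses an ONNX graph into a migraphx::program: each supported ONNX
// operator name is registered in the `ops` table together with a callback
// that translates one node (its attributes plus already-parsed inputs)
// into one or more migraphx instructions.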
struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog    = program();
    bool is_pytorch = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        add_generic_op("Relu", op::relu{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Abs", op::abs{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Atan", op::atan{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ArgMax", &onnx_parser::parse_argmax);
        add_mem_op("ArgMin", &onnx_parser::parse_argmin);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Softmax", &onnx_parser::parse_softmax);
        add_mem_op("LogSoftmax", &onnx_parser::parse_logsoftmax);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support name format of all lower case or the first letter capital
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

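    // Registers a parse callback that produces a single instruction; the
    // result is wrapped in a vector so that single- and multi-output
    // operators can share one table. A minimal sketch of a registration
    // (illustrative only; "Neg" is not one of the operators added above):
    //
    //   add_op("Neg", [this](attribute_map, std::vector<instruction_ref> args) {
    //       return prog.add_instruction(op::neg{}, args);
    //   });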
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [](auto a, auto b) { return std::max(a, b); });

        return out_lens;
    }

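    // Applies a binary operation, multibroadcasting both arguments to a
    // common shape first when they differ. For example, adding shapes
    // (3,2,4,5) and (2,1,1) multibroadcasts both sides to (3,2,4,5);
    // matching shapes pass through without extra instructions.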
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);
            auto l0       = prog.add_instruction(op::multibroadcast{out_lens}, arg0);
            auto l1       = prog.add_instruction(op::multibroadcast{out_lens}, arg1);
            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

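    // Registers an n-ary operator as a left fold of its binary form, so
    // Sum(a, b, c) becomes add(add(a, b), c), with broadcasting applied
    // at each step by add_broadcastable_binary_op.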
    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    instruction_ref parse_clip(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        op::clip op;
        if(contains(attributes, "max"))
        {
            op.max_val = parse_value(attributes.at("max")).at<float>();
        }
        if(contains(attributes, "min"))
        {
            op.min_val = parse_value(attributes.at("min")).at<float>();
        }
        return prog.add_instruction(op, std::move(args));
    }

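    // ONNX Softmax treats the input as a 2D matrix of (batch, features),
    // while the softmax operator used here expects a 4D tensor, so the
    // input is reshaped to (N, C, 1, 1) and back again around the softmax.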
    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto dims = args.front()->get_shape().lens();
        auto r =
            prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, args.front());
        auto s = prog.add_instruction(op::softmax{}, r);
        return prog.add_instruction(op::reshape{{long(dims[0]), long(dims[1])}}, s);
    }

    instruction_ref parse_logsoftmax(const std::string&,
                                     const attribute_map& attributes,
                                     std::vector<instruction_ref> args)
    {
        int axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
    }

    instruction_ref parse_argmax(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
        }
        else
        {
            return prog.add_instruction(op::argmax{axis}, std::move(args));
        }
    }

    instruction_ref parse_argmin(const std::string&,
                                 const attribute_map& attributes,
                                 std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{static_cast<int64_t>(axis)}}, ins);
        }
        else
        {
            return prog.add_instruction(op::argmin{axis}, std::move(args));
        }
    }

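    // Symmetric padding maps directly onto the convolution operator's
    // padding attribute; asymmetric padding is lowered to an explicit pad
    // instruction instead. A third argument, when present, is the bias,
    // which is broadcast over the channel axis and added to the output.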
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
            if(contains(attributes, "auto_pad"))
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0      = prog.add_instruction(op::pad{padding}, l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "dilations"))
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }
        if(contains(attributes, "group"))
        {
            op.group = parse_value(attributes.at("group")).at<int>();
        }
        if(args.size() == 3)
        {
            uint64_t axis = 1;
            // use l0, the (possibly explicitly padded) input
            auto l1       = prog.add_instruction(op, l0, args[1]);
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
        return prog.add_instruction(op, l0, args[1]);
    }

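    // Pooling follows the same padding strategy as convolution, except the
    // explicit pad instruction fills with the lowest float value so padded
    // elements can never win a max-pooling window.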
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens  = args.front()->get_shape().lens();
            op.lengths = {lens[2], lens[3]};
        }
        if(contains(attributes, "pads"))
        {
            std::vector<std::int64_t> padding;
            copy(attributes["pads"].ints(), std::back_inserter(padding));
            if(padding.size() != 4)
            {
                MIGRAPHX_THROW("padding should have 4 values");
            }
            if(padding[0] != padding[2] || padding[1] != padding[3])
            {
                // insert zeros for pad op (args[0] has 4 dims)
                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
                l0 = prog.add_instruction(op::pad{padding, std::numeric_limits<float>::lowest()},
                                          l0);
            }
            else
            {
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        if(contains(attributes, "strides"))
        {
            copy(attributes["strides"].ints(), op.stride.begin());
        }
        if(contains(attributes, "kernel_shape"))
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
        if(contains(attributes, "auto_pad"))
        {
            auto s = attributes["auto_pad"].s();
            if(s.find("SAME_UPPER") == std::string::npos)
            {
                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
            }
            op.padding_mode = op::padding_mode_t::same;
        }

        return prog.add_instruction(op, l0);
    }

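    // Reshape takes its target shape either from the "shape" attribute
    // (opset 1) or from a second constant input (opset 5); a shape input
    // that cannot be evaluated at parse time is rejected as dynamic.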
    instruction_ref
    parse_reshape(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            if(s.empty())
                MIGRAPHX_THROW("Dynamic shape is not supported.");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        uint64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::size_t axis = parse_value(attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
        }
        op::gather op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::slice op;
        if(contains(attributes, "axes"))
        {
            literal s = parse_value(attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }
        {
            literal s = parse_value(attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }
        {
            literal s = parse_value(attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }
        return prog.add_instruction(op, args[0]);
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        literal v     = parse_value(attributes.at("value"));
        auto dim_size = attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

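    // Gemm computes alpha * op(A) * op(B) + beta * C, where op() is an
    // optional transpose of the last two axes. For example, with transA
    // set, A of shape (6,3) becomes (3,6); multiplied by B of shape (6,4)
    // the output is (3,4), and a C of shape (4) is multibroadcast to
    // (3,4) before being handed to the dot instruction.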
    instruction_ref
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        if(contains(attributes, "beta"))
        {
            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
            transa = parse_value(attributes.at("transA")).at<bool>();
        }
        if(contains(attributes, "transB"))
        {
            transb = parse_value(attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }

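    // MatMul follows numpy matmul semantics: a 1-D left argument gets a 1
    // prepended and a 1-D right argument gets a 1 appended (squeezed back
    // out of the result afterwards), and leading batch dimensions are
    // multibroadcast to a common shape. For example, multiplying shapes
    // (2,1,3,4) and (5,4,6) broadcasts the batch dims to (2,5) and
    // produces a (2,5,3,6) result.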
    instruction_ref
    parse_matmul(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // args[1] is a vector, append 1 to the shape
        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(op::dot{1.0f, 0.0f}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        bool is_test                                      = false;
        if(contains(attributes, "epsilon"))
        {
            epsilon = parse_value(attributes.at("epsilon")).at<float>();
        }
        if(contains(attributes, "momentum"))
        {
            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "is_test"))
        {
            is_test = parse_value(attributes.at("is_test")).at<uint64_t>() > 0;
        }
        if(contains(attributes, "spatial"))
        {
            bn_mode = (parse_value(attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        (void)is_test;
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref parse_leaky_relu(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(attributes, "alpha"))
        {
            alpha = parse_value(attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(attributes, "alpha"))
            alpha = parse_value(attributes.at("alpha")).at<float>();
        if(contains(attributes, "beta"))
            beta = parse_value(attributes.at("beta")).at<float>();
        if(contains(attributes, "bias"))
            bias = parse_value(attributes.at("bias")).at<float>();
        if(contains(attributes, "size"))
            size = parse_value(attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }

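    // ImageScaler multiplies the whole input by a single scale and then
    // adds a per-channel bias, broadcast along axis 1 of the NCHW input.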
    instruction_ref parse_imagescaler(const std::string&,
                                      attribute_map attributes,
                                      std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(attributes, "scale"))
        {
            scale = parse_value(attributes.at("scale")).at<float>();
        }

        if(contains(attributes, "bias"))
        {
            auto&& bias_floats = attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_lens = args.front()->get_shape().lens();

        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(attributes, "perm"))
        {
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref
    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        float value = 0.0f;
        if(contains(attributes, "pads"))
        {
            auto&& pad_vals = attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }
        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }
        if(contains(attributes, "mode"))
        {
            auto mode = attributes.at("mode").s();
            if(mode != "constant")
                MIGRAPHX_THROW("migraphx currently only supports constant padding");
        }
        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
    // Use a literal instruction to replace the shape operator, since the
    // output of the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
    instruction_ref parse_constant_fill(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(attributes, "dtype"))
        {
            dtype = parse_value(attributes.at("dtype")).at<int>();
        }
        migraphx::shape::type_t type = get_type(dtype);

        if(contains(attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
        }

        if(contains(attributes, "value"))
        {
            value = parse_value(attributes.at("value")).at<float>();
        }

        if(contains(attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            if(in.empty())
            {
                MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
            }

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

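    // Per the ONNX spec, RNN takes up to 6 inputs: X, W, R, B,
    // sequence_lens, and initial_h. W has shape
    // (num_directions, hidden_size, input_size), which is why hidden_size
    // is read from input 1 and cross-checked against the attribute below.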
    std::vector<instruction_ref>
    parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions,
        // one for forward and the other for reverse. If only one
        // activation function is provided, use it for both directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
            return map_actv_funcs[fn];
        });

        // To be added later
        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators so there are always 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output is the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output is the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

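    // GRU takes the same 6 inputs as RNN, but its gate weights are stacked
    // as (num_directions, 3*hidden_size, input_size), so hidden_size is
    // read from the recurrence weight R (input 2), whose last dimension
    // is hidden_size.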
    std::vector<instruction_ref>
    parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 activation function is
            // provided, repeat it four times. If 2 are provided, assume
            // forward and reverse use the same pair. If 3 are provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output is the concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output is the last gru output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        return {hidden_states, last_output};
    }

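    // Per the ONNX spec, LSTM takes up to 8 inputs (X, W, R, B,
    // sequence_lens, initial_h, initial_c, and the peephole weights P) and
    // three activations per direction (f, g, h), defaulting to
    // sigmoid, tanh, tanh.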
Shucai Xiao's avatar
Shucai Xiao committed
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
    std::vector<instruction_ref>
    parse_lstm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(attributes, "direction"))
        {
            direction = attributes.at("direction").s();
        }

Shucai Xiao's avatar
Shucai Xiao committed
1130
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
1131
1132
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1133
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
1134
1135
1136
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
1137
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
1138
        }
Shucai Xiao's avatar
Shucai Xiao committed
1139
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
1140
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(attributes, "activations"))
        {
            auto names = attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario, but the ONNX operator spec does not say how to
            // expand a shorter list. We use this algorithm: if 1
            // activation function is provided, repeat it six times; if
            // 2 are provided, repeat the 2nd once to make three, then
            // repeat all three; if 3 are provided, repeat all three
            // once. For 4 or 5 functions, the last one is repeated to
            // fill the six slots. This may need to change later.
            switch(vec_names.size())
            {
            case 1:
                vec_names = {vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0),
                             vec_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2)};
                break;

            case 4:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(3),
                             vec_names.at(3)};
                break;

            case 5:
                vec_names = {vec_names.at(0),
                             vec_names.at(1),
                             vec_names.at(2),
                             vec_names.at(3),
                             vec_names.at(4),
                             vec_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(vec_names.size())
            {
            case 1: vec_names = {vec_names.at(0), vec_names.at(0), vec_names.at(0)}; break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                vec_names = {vec_names.at(0), vec_names.at(1), vec_names.at(1)};
                break;

            default: break;
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
            return map_actv_funcs[name];
        });

        float clip = 0.0;
        if(contains(attributes, "clip"))
        {
            clip = parse_value(attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(attributes, "input_forget"))
        {
            input_forget = parse_value(attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to pad the argument list to 8
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for last lstm output
        auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::lstm_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

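    // ONNX ReduceSum sums over the given axes (all axes by default).
    // With keepdims=1 the reduced dimensions are kept with extent 1;
    // with keepdims=0 they are squeezed away after the reduction.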
    instruction_ref parse_reduce_sum(const std::string&,
                                     attribute_map attributes,
                                     std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<std::size_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = attributes["axes"].ints();
            axes             = std::vector<std::size_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(attributes, "keepdims"))
        {
            keep_dims = parse_value(attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(op::reduce_sum{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
            std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
            return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
        }
    }

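    // deserialize a binary ONNX protobuf from the stream and build the
    // program from its graph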
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_graph(const onnx::GraphProto& graph)
    {
        nodes = get_nodes(graph);
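        // initializers hold the trained weights; index them by name so
        // inputs backed by an initializer become literals rather than
        // runtime parameters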
        std::unordered_map<std::string, onnx::TensorProto> initializer_data;
        for(auto&& f : graph.initializer())
        {
            initializer_data[f.name()] = f;
        }
        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // Does the input have an initializer?
            if(contains(initializer_data, name))
            {
                auto t             = initializer_data[name];
                instructions[name] = prog.add_literal(parse_tensor(t));
            }
            else
            {
                // TODO: Get shape of input parameter
                shape s            = parse_type(input.type());
                instructions[name] = prog.add_parameter(name, s);
            }
        }
        for(auto&& output : graph.output())
        {
            this->parse_node(output.name());
        }
    }

    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

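    // Parse a node after recursively parsing the producers of its
    // inputs, so instructions are added in topological order; parsed
    // nodes are cached in the instructions map.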
    void parse_node(const std::string& name)
    {
        if(name.empty())
            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(nodes.count(input) > 0)
                {
                    assert(name != input);
                    this->parse_node(input);
                }
                else if(input.empty())
                {
                    this->parse_undefined(input);
                }
                args.push_back(instructions.at(input));
            }
            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
                result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
            }
            else
            {
                result = ops[node.op_type()](get_attributes(node), args);
            }
            // even nodes with no declared ONNX outputs produce an output in migraphx
            if(node.output().empty() and result.size() == 1)
1401
            {
                instructions[name] = result.front();
            }
            else
            {
                assert(node.output().size() >= result.size());
                std::transform(result.begin(),
                               result.end(),
                               node.output().begin(),
                               std::inserter(instructions, instructions.end()),
                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
            }
        }
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
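        // index each node by its output names so parse_node can find
        // the producer of any tensor; nodes without outputs are stored
        // under their own (possibly generated) name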
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
            if(node.output().empty())
            {
                if(node.name().empty())
                {
                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
                    n++;
                }
                else
                {
                    result[node.name()] = node;
                }
            }
            for(auto&& output : node.output())
            {
                result[output] = node;
            }
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::UNDEFINED: return {};
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::STRING: return {};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::GRAPH: return {};
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::STRINGS: return {};
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
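        // a tensor stores its data either as raw bytes or in typed
        // repeated fields; handle the raw-byte encoding first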
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::UINT8: throw std::runtime_error("");
            case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT16:
                return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::STRING: throw std::runtime_error("");
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::UINT32: throw std::runtime_error("");
            case onnx::TensorProto::UINT64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::UINT8: throw std::runtime_error("");
        case onnx::TensorProto::INT8:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::UINT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT16:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT32:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::STRING: throw std::runtime_error("");
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::FLOAT16:
        {
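            // FLOAT16 values arrive packed as 16-bit patterns in
            // int32_data; reinterpret each pattern as a half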
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::UINT32: throw std::runtime_error("");
        case onnx::TensorProto::UINT64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    static shape parse_type(const onnx::TypeProto& t)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::UNDEFINED:
            break; // throw std::runtime_error("Unsupported type UNDEFINED");
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::UINT8:
            break; // throw std::runtime_error("Unsupported type UINT8");
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::STRING:
            break; // throw std::runtime_error("Unsupported type STRING");
        case onnx::TensorProto::BOOL:
            break; // throw std::runtime_error("Unsupported type BOOL");
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::COMPLEX64:
            break; // throw std::runtime_error("Unsupported type COMPLEX64");
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type COMPLEX128");
        }
        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [](auto&& d) -> std::size_t {
                           if(not d.has_dim_value())
                           {
                               long default_batch_size = 1; // FIXME
                               return default_batch_size;
                           }
                           return d.dim_value();
                       });
        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
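        // these integer codes are the onnx::TensorProto::DataType enum
        // values (1 = FLOAT ... 13 = UINT64)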
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
};

program parse_onnx(const std::string& name)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    onnx_parser parser;
#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    return std::move(parser.prog);
}
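
// A minimal usage sketch (the file name is a placeholder):
//     migraphx::program p = migraphx::parse_onnx("model.onnx");
//     std::cout << p << std::endl; // print the parsed program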

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx