#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace onnx = onnx_for_migraphx;

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog                  = program();
    bool is_pytorch               = false;
    std::size_t default_dim_value = 1;
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
    bool skip_unknown_operators = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

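    // The constructor registers every supported ONNX operator: simple
    // elementwise operators go through add_generic_op / add_binary_op /
    // add_variadic_op, while operators that need attribute or multi-input
    // handling are bound to a parse_* member function through add_mem_op.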
    onnx_parser()
    {
        // onnx operators sorted alphabetically by name
        add_generic_op("Abs", op::abs{});
        add_generic_op("Acos", op::acos{});
        add_generic_op("Acosh", op::acosh{});
        add_generic_op("Asin", op::asin{});
        add_generic_op("Asinh", op::asinh{});
        add_generic_op("Atan", op::atan{});
        add_generic_op("Atanh", op::atanh{});
        add_generic_op("Ceil", op::ceil{});
        add_generic_op("Cos", op::cos{});
        add_generic_op("Cosh", op::cosh{});
        add_generic_op("Erf", op::erf{});
        add_generic_op("Exp", op::exp{});
        add_generic_op("Dropout", op::identity{});
        add_generic_op("Floor", op::floor{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("Log", op::log{});
        add_generic_op("Neg", op::neg{});
        add_generic_op("Reciprocal", op::recip{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Round", op::round{});
        add_generic_op("Sigmoid", op::sigmoid{});
        add_generic_op("Sign", op::sign{});
        add_generic_op("Sin", op::sin{});
        add_generic_op("Sinh", op::sinh{});
        add_generic_op("Sqrt", op::sqrt{});
        add_generic_op("Tan", op::tan{});
        add_generic_op("Tanh", op::tanh{});

        add_binary_op("Add", op::add{});
        add_binary_op("Div", op::div{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("PRelu", op::prelu{});
        add_binary_op("Sub", op::sub{});

        add_variadic_op("Sum", op::add{});
        add_variadic_op("Max", op::max{});
        add_variadic_op("Min", op::min{});

        add_mem_op("ATen", &onnx_parser::parse_aten);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", &onnx_parser::parse_arg_op<op::argmax>);
        add_mem_op("ArgMin", &onnx_parser::parse_arg_op<op::argmin>);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Concat", &onnx_parser::parse_concat);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", &onnx_parser::parse_conv<op::convolution>);
        add_mem_op("ConvInteger", &onnx_parser::parse_conv<op::quant_convolution>);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("Flatten", &onnx_parser::parse_flatten);
        add_mem_op("Gather", &onnx_parser::parse_gather);
        add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("LogSoftmax", &onnx_parser::parse_softmax<op::logsoftmax>);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("MatMul", &onnx_parser::parse_matmul<op::dot>);
        add_mem_op("MatMulInteger", &onnx_parser::parse_matmul<op::quant_dot>);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Range", &onnx_parser::parse_range);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", &onnx_parser::parse_reduce_oper<op::reduce_max>);
        add_mem_op("ReduceMean", &onnx_parser::parse_reduce_oper<op::reduce_mean>);
        add_mem_op("ReduceMin", &onnx_parser::parse_reduce_oper<op::reduce_min>);
        add_mem_op("ReduceProd", &onnx_parser::parse_reduce_oper<op::reduce_prod>);
        add_mem_op("ReduceSum", &onnx_parser::parse_reduce_oper<op::reduce_sum>);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Softmax", &onnx_parser::parse_softmax<op::softmax>);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Support names that are all lower case or have the first letter capitalized
        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
    }

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(std::string name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

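    // Registers a binary operator. If the legacy "broadcast"/"axis" attributes
    // (older ONNX opsets) are present, the second operand is broadcast along the
    // given axis; otherwise numpy-style multidirectional broadcasting is applied
    // through add_broadcastable_binary_op.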
    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name, [this, x](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = prog.add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                  args[1]);
                    return prog.add_instruction(x, args[0], l);
                }
                return prog.add_instruction(x, args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins)
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return prog.add_instruction(op::contiguous{}, ins);
    }

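    // Applies numpy-style broadcasting before a binary operation: when the two
    // argument shapes differ, both are multibroadcast to the common shape from
    // compute_broadcasted_lens, e.g. {3,2,4,5} and {2,1,1} -> {3,2,4,5}.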
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = prog.add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = prog.add_instruction(op::multibroadcast{out_lens}, arg1);

            return prog.add_instruction(x, l0, l1);
        }
        else
        {
            return prog.add_instruction(x, {arg0, arg1});
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

    template <class T>
    void add_variadic_op(std::string name, T x)
    {
        add_op(name, [this, x](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, x](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, x);
                                   });
        });
    }

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref
    add_bias(const std::vector<instruction_ref>& args, instruction_ref curr_ins, uint64_t axis)
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                prog.add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return prog.add_instruction(op::add{}, curr_ins, bias_bcast);
        }
        return curr_ins;
    }

    static bool is_asym_padding(const std::vector<int64_t>& padding)
    {
        assert(padding.size() % 2 == 0);
        size_t pad_ndims = padding.size() / 2;

        for(size_t i = 0; i < pad_ndims; i++)
        {
            if(padding[i] != padding[i + pad_ndims])
            {
                return true;
            }
        }
        return false;
    }

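    // The convolution and pooling operators only carry symmetric padding, so if
    // the ONNX "pads" attribute is asymmetric (or count_include_pad is set) an
    // explicit pad instruction is inserted in front of the input instead of
    // filling in op.padding.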
    template <class Op>
    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            Op& op,
                            int count_include_pad = 0,
                            float pad_val         = 0)
    {
        size_t pad_ndims  = padding.size() / 2;
        auto left_pad_it  = padding.begin();
        auto right_pad_it = left_pad_it + pad_ndims;

        if(is_asym_padding(padding) or count_include_pad == 1)
        {
            std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
            // add left pads
            asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
            // add right pads
            asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
            ins = prog.add_instruction(op::pad{asym_pads, pad_val}, ins);
        }
        else
        {
            op.padding = std::vector<size_t>(left_pad_it, right_pad_it);
        }
    }

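    // Clip: newer opsets pass min/max as optional inputs, older opsets as
    // attributes. The bounds are multibroadcast to the input shape; a missing
    // max degenerates to an elementwise max with the lower bound, and no bounds
    // at all becomes an identity.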
    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3)
        {
            min_arg  = args[1];
            max_arg  = args[2];
            min_used = true;
            max_used = true;
        }
        else if(args.size() == 2)
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = prog.add_literal(min_val);
            max_arg       = prog.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
            min_arg = prog.add_instruction(op::multibroadcast{input_lens}, min_arg);

        if(max_used)
            max_arg = prog.add_instruction(op::multibroadcast{input_lens}, max_arg);

        if(min_used and max_used)
            return prog.add_instruction(op::clip{}, args[0], min_arg, max_arg);
        if(min_used)
            return prog.add_instruction(op::max{}, args[0], min_arg);

        return prog.add_instruction(op::identity{}, args[0]);
    }

    template <class Op>
    instruction_ref
    parse_softmax(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        return prog.add_instruction(Op{axis}, std::move(args));
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = prog.add_instruction(Op{axis}, std::move(args));
            return prog.add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return prog.add_instruction(Op{axis}, std::move(args));
        }
    }

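    // Reflect padding is emulated with slice + concat. calc_reflect_indices
    // picks which rows/columns of the input to mirror; when the requested pad
    // width exceeds the dimension size, the indices repeat periodically, e.g.
    // for a dimension of size 4 the sequence is 1, 2, 3, 2, 1, 0, 1, ...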
    void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
    {
        int k         = 0;
        bool reversed = false;
        // in reflect padding, if the num_pads > num_dims,
        // compute the extra pad indices periodically, ex. ( 1, 2, 3, 2, 1, 0)
        for(int& idx : indices)
        {
            if(k == num_dims - 1)
                reversed = true;
            if(k == 0)
                reversed = false;
            if(reversed)
                k--;
            else
                k++;
            idx = k;
        }
    }

    instruction_ref reflect_pad(const std::vector<int64_t>& pads, instruction_ref input)
    {
        size_t num_dims = pads.size() / 2;
        std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
        std::vector<int> rdims(pads.begin() + num_dims, pads.end());
        assert(ldims.size() == rdims.size());

        std::vector<int64_t> axes(num_dims);
        std::iota(axes.begin(), axes.end(), int64_t{0});

        // iterate over dimensions, starting from lowest dimension
        for(int64_t i = num_dims - 1; i >= 0; i--)
        {
            auto axis   = i;
            auto lcount = ldims.at(i);
            auto rcount = rdims.at(i);
            if(lcount == 0 and rcount == 0) // no padding for current dim
                continue;

            // calculate starts and ends for each iteration since shape may change
            std::vector<size_t> dims = input->get_shape().lens();
            std::vector<int64_t> starts(axes.size(), 0);
            std::vector<int64_t> ends(dims.begin(), dims.end());
            std::vector<instruction_ref> slices;

            auto starts_it = starts.begin() + i;
            auto ends_it   = ends.begin() + i;
            auto dims_it   = dims.begin() + i;

            std::vector<int> l_indices(lcount);
            std::vector<int> r_indices(rcount);

            // compute slice indices in a periodic fashion
            calc_reflect_indices(l_indices, *dims_it);
            calc_reflect_indices(r_indices, *dims_it);

            for(int idx : l_indices)
            {
                *starts_it = idx;
                *ends_it   = *starts_it + 1;
                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
            }
            // when padding on the left side, the outermost pad should be at the beginning
            std::reverse(slices.begin(), slices.end());
            slices.push_back(input);
            for(int idx : r_indices)
            {
                *starts_it = *dims_it - idx - 1;
                *ends_it   = *starts_it + 1;
                slices.push_back(prog.add_instruction(op::slice{axes, starts, ends}, input));
            }
            input = prog.add_instruction(op::concat{axis}, slices);
        }
        return input;
    }

    void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg)
    {
        if(kdims != attr_size)
        {
            MIGRAPHX_THROW(error_msg + " k-dims: " + to_string(kdims) +
                           " attribute size: " + to_string(attr_size));
        }
    }

    template <class Op>
    void recalc_conv_attributes(Op& op, size_t kdims)
    {
        if(op.padding.size() != kdims)
        {
            op.padding.resize(kdims);
            std::fill_n(op.padding.begin(), kdims, 0);
        }
        if(op.stride.size() != kdims)
        {
            op.stride.resize(kdims);
            std::fill_n(op.stride.begin(), kdims, 1);
        }
        if(op.dilation.size() != kdims)
        {
            op.dilation.resize(kdims);
            std::fill_n(op.dilation.begin(), kdims, 1);
        }
    }

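    // Resolves auto_pad = SAME_UPPER / SAME_LOWER into explicit per-dimension
    // pads: for each spatial dimension calculate_padding is given the input
    // length, stride, dilation and kernel length so that the output keeps the
    // ONNX "same" size (ceil(input / stride)); is_same_upper selects which side
    // receives the odd padding element.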
    template <class Op>
    static void cal_auto_padding_size(node_info info,
                                      Op& op,
                                      const std::vector<std::size_t>& k_lens,
                                      const std::vector<std::size_t>& dilation,
                                      const std::vector<std::size_t>& in_lens,
                                      std::vector<int64_t>& paddings)
    {
        size_t kdims = in_lens.size() - 2;
        assert(k_lens.size() == kdims and dilation.size() == kdims);

        if(!contains(info.attributes, "auto_pad"))
        {
            return;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            op.padding_mode    = op::padding_mode_t::same;
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            paddings.resize(2 * kdims);

            for(size_t i = 0; i < paddings.size() / 2; i++)
            {
                calculate_padding(i,
                                  paddings,
                                  in_lens[i + 2],
                                  op.stride[i],
                                  dilation[i],
                                  k_lens[i],
                                  is_same_upper);
            }
        }
    }

    static void check_padding_mode(node_info info, const std::string& op_name)
    {
        // ensure "pads" is specified only when auto_pad is "NOT_SET"
        if(contains(info.attributes, "pads") and contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("PARSE_" + op_name +
                               ": auto_pad and padding cannot be specified simultaneously");
            }
        }
    }

    template <class Op>
    instruction_ref
    parse_conv(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        Op op;
        auto l0      = args[0];
        auto weights = args[1];
        auto in_lens = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        // ensure "pads" is specified only when auto_pad is "NOT_SET"
        check_padding_mode(info, "CONV");

        if(contains(info.attributes, "strides"))
        {
            op.stride.clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
            check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV: inconsistent strides");
        }
        if(contains(info.attributes, "dilations"))
        {
            op.dilation.clear();
            copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
            check_attr_sizes(kdims, op.dilation.size(), "PARSE_CONV: inconsistent dilations");
        }

        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            op.padding.clear();
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto weight_lens = weights->get_shape().lens();

            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
            cal_auto_padding_size(info, op, k_lens, op.dilation, in_lens, padding);
        }
        check_asym_padding(l0, padding, op);

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        recalc_conv_attributes(op, kdims);

        auto l1 = prog.add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::deconvolution op;
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asym_padding = false;
        auto in_lens      = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        // ensure "pads" is specified only when auto_pad is "NOT_SET"
        check_padding_mode(info, "CONV_TRANSPOSE");

        if(contains(info.attributes, "pads"))
        {
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));

            asym_padding = is_asym_padding(padding);

            if(not asym_padding)
            {
                size_t pad_ndims = padding.size() / 2;
                check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
                op.padding.clear();
                std::transform(padding.begin(),
                               padding.begin() + pad_ndims,
                               std::back_inserter(op.padding),
                               [](auto pad_val) { return pad_val; });
            }
        }
        if(contains(info.attributes, "strides"))
        {
            op.stride.clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
            check_attr_sizes(kdims, op.stride.size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
        }
        if(contains(info.attributes, "dilations"))
        {
            op.dilation.clear();
            copy(info.attributes["dilations"].ints(), std::back_inserter(op.dilation));
            check_attr_sizes(
                kdims, op.dilation.size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
                               "simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
            }
        }

        if(contains(info.attributes, "group"))
        {
            op.group = parse_value(info.attributes.at("group")).at<int>();
        }

        recalc_conv_attributes(op, kdims);

        auto l1                   = prog.add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
        if(asym_padding)
        {
            std::vector<int64_t> axes(kdims);
            std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims

            auto pad_kdim_start = padding.begin() + kdims;
            std::vector<int64_t> starts(padding.begin(), pad_kdim_start);

            std::vector<int64_t> ends{};
            std::transform(curr_shape.begin(),
                           curr_shape.end(),
                           pad_kdim_start,
                           std::back_inserter(ends),
                           [](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });

            l1 = prog.add_instruction(op::slice{axes, starts, ends}, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            size_t non_kdims = dims.size() * 2 - kdims;
            std::vector<int64_t> output_padding(non_kdims, 0);
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            check_attr_sizes(kdims,
                             output_padding.size() - non_kdims,
                             "PARSE_CONV_TRANSPOSE: inconsistent output padding");
            l1 = prog.add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            check_attr_sizes(
                kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
            dims = to_int64_vector(l1->get_shape().lens());
            copy(dims.begin() + 2, dims.end(), curr_shape.begin());
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
                std::transform(output_shape.begin(),
                               output_shape.end(),
                               curr_shape.begin(),
                               std::back_inserter(target_padding),
                               [](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
                l1 = prog.add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

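    // Average pooling without count_include_pad cannot use an explicit pad
    // instruction, so asymmetric pads are instead grown on the smaller side to
    // make them symmetric. s_start records how many extra output elements this
    // introduces per dimension; parse_pooling slices them off again after the
    // pooling instruction.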
    static void
    tune_padding_to_symmetric(int64_t& left, int64_t& right, const int stride, int64_t& s_start)
    {
        s_start = 0;
        if(left > right)
        {
            right = left;
        }
        else if(left < right)
        {
            auto diff = right - left;
            s_start   = (diff + stride - 1) / stride;
            left      = left + s_start * stride;
            right     = left;
        }
    }

    static void tune_padding_size(const op::pooling& op,
                                  std::vector<int64_t>& padding,
                                  int count_include_pad,
                                  std::vector<int64_t>& s_start)
    {
        // for max pooling, or when count_include_pad is 1, no change is required.
        if(op.mode == "max" or count_include_pad == 1)
        {
            return;
        }

        // if padding is symmetric, return directly
        if(!is_asym_padding(padding))
        {
            return;
        }

        // asymmetric padding, make it symmetric
        std::size_t n_dims = padding.size() / 2;
        s_start.resize(n_dims);
        for(std::size_t i = 0; i < n_dims; ++i)
        {
            tune_padding_to_symmetric(padding[i], padding[i + n_dims], op.stride[i], s_start[i]);
        }
    }

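    // Pooling (MaxPool / AveragePool / Global*Pool). Global variants use the
    // whole spatial extent as the kernel. Explicit pads, auto_pad and
    // count_include_pad are normalized here: average pooling with asymmetric
    // pads is symmetrized (see tune_padding_size above) and the surplus output
    // sliced away, while the remaining asymmetric cases are lowered to an
    // explicit pad instruction through check_asym_padding.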
    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
        auto l0      = args[0];
        auto in_lens = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        if(starts_with(name, "Global"))
        {
            op.lengths = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
        }

        // does not support ceil_mode
        if(contains(info.attributes, "ceil_mode"))
        {
            if(info.attributes.at("ceil_mode").i() == 1)
            {
                MIGRAPHX_THROW("PARSE_POOLING: pool does not support ceil_mode");
            }
        }

        // when count_include_pad is 1, we always use explicit padding
        int count_include_pad = 0;
        if(contains(info.attributes, "count_include_pad"))
        {
            count_include_pad = info.attributes.at("count_include_pad").i();
        }

        if(contains(info.attributes, "strides"))
        {
            op.stride.clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(op.stride));
            check_attr_sizes(kdims, op.stride.size(), "PARSE_POOLING: inconsistent strides");
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            op.lengths.clear();
            copy(info.attributes["kernel_shape"].ints(), std::back_inserter(op.lengths));
            check_attr_sizes(kdims, op.lengths.size(), "PARSE_POOLING: inconsistent lengths");
        }

        // ensure "pads" is specified only when auto_pad is "NOT_SET"
        check_padding_mode(info, "POOLING");

        std::vector<int64_t> paddings;
        float pad_val = ((op.mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
        if(contains(info.attributes, "pads"))
        {
            op.padding.clear();
            copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
            check_attr_sizes(
                kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
        }

        if(contains(info.attributes, "auto_pad"))
        {
            op.padding.clear();
            // the computed paddings may be empty; missing entries default to 0 (no padding)
            cal_auto_padding_size(info, op, op.lengths, {1, 1}, in_lens, paddings);
        }

        if(paddings.size() != 2 * kdims)
        {
            paddings.resize(kdims * 2);
            std::fill_n(paddings.begin(), 2 * kdims, 0);
        }

        if(op.padding.size() != kdims)
        {
            op.padding.resize(kdims);
            std::fill_n(op.padding.begin(), kdims, 0);
        }

        if(op.stride.size() != kdims)
        {
            op.stride.resize(kdims);
            std::fill_n(op.stride.begin(), kdims, 1);
        }
        // used to calculate the supposed output shape
        std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());

        std::vector<int64_t> slice_start;
        std::vector<int64_t> slice_end;
        tune_padding_size(op, paddings, count_include_pad, slice_start);

        if(!slice_start.empty())
        {
            // calculate expected output shape
            orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
            orig_padding.insert(orig_padding.begin(), 2, 0);
            op::pad pad{orig_padding, 0.0f};
            shape padded_shape = pad.compute_shape({l0->get_shape()});
            auto out_lens      = op.compute_shape({padded_shape}).lens();

            // compute slice_end information
            slice_end.resize(slice_start.size());
            std::transform(out_lens.begin() + 2,
                           out_lens.end(),
                           slice_start.begin(),
                           slice_end.begin(),
                           [](auto i, auto j) { return i + j; });
        }

        check_asym_padding(l0, paddings, op, count_include_pad, pad_val);
        in_lens = l0->get_shape().lens();
        for(size_t i = 0; i < kdims; i++)
        {
            if(op.lengths[i] > in_lens[i + 2] + 2 * op.padding[i])
            {
                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
            }
        }

        auto l1 = prog.add_instruction(op, l0);
        if(!slice_start.empty())
        {
            std::vector<int64_t> axes(kdims);
            std::iota(axes.begin(), axes.end(), 2);
            l1 = prog.add_instruction(op::slice{axes, slice_start, slice_end}, l1);
        }

        return l1;
    }

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_flatten(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 1;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }
        return prog.add_instruction(op::flatten{axis}, args[0]);
    }

    instruction_ref
    parse_squeeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::squeeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_unsqueeze(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::unsqueeze op;
        literal s = parse_value(info.attributes.at("axes"));
        s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        return prog.add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref
    parse_concat(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // change to handle negative axis values
        if(!contains(info.attributes, "axis"))
        {
            MIGRAPHX_THROW("PARSE_CONCAT: attribute axis is required!");
        }

        int axis = parse_value(info.attributes.at("axis")).at<int>();
        op::concat op{axis};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_gather(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        op::gather op{axis};
        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
    }

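    // GatherElements is lowered to a plain gather on a flattened view of the
    // data: for every position of the index tensor the linear offset of that
    // position in the data shape is precomputed, then shifted along the gather
    // axis by (index_value - axis_coordinate) * axis_stride, and the result is
    // gathered from the 1-D reshaped input.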
    instruction_ref
    parse_gather_elements(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        // standardize input data and index
        auto arg_data = make_contiguous(args[0]);
        auto arg_ind  = make_contiguous(args[1]);

        auto data_s = arg_data->get_shape();
        auto ind_s  = arg_ind->get_shape();

        if(data_s.lens().size() != ind_s.lens().size())
        {
            MIGRAPHX_THROW("PARSE_GATHER_ELEMENTS: input data and index must have the same rank!");
        }

        int n_rank     = static_cast<int>(data_s.lens().size());
        int tuned_axis = (axis < 0) ? (axis + n_rank) : axis;

        auto axis_stride      = data_s.strides()[tuned_axis];
        int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
        // reshape the input data as one dimension and used as input data
        // to the gather operator
        arg_data = prog.add_instruction(op::reshape{{data_elem_num}}, arg_data);

        std::size_t elem_num = ind_s.elements();
        std::vector<int> ind_index(elem_num);
        std::iota(ind_index.begin(), ind_index.end(), 0);

        // convert index in input indices to that in input data
        std::vector<int> data_indices(elem_num);
        std::transform(ind_index.begin(), ind_index.end(), data_indices.begin(), [&](auto i) {
            return data_s.index(ind_s.multi(i));
        });

        std::vector<int> vec_axis_ind(elem_num);
        std::transform(ind_index.begin(), ind_index.end(), vec_axis_ind.begin(), [&](auto i) {
            return ind_s.multi(i)[tuned_axis];
        });

        auto l_shape_idx =
            prog.add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
        auto l_dim_idx = prog.add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
        auto l_stride  = prog.add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
        l_stride       = prog.add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
        auto dim_diff  = prog.add_instruction(op::sub{}, arg_ind, l_dim_idx);
        auto delta     = prog.add_instruction(op::mul{}, dim_diff, l_stride);
        auto ind       = prog.add_instruction(op::add{}, l_shape_idx, delta);

        op::gather op{0};
        return prog.add_instruction(op, arg_data, ind);
    }

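    // Slice: newer opsets pass starts/ends/axes/steps as inputs, older opsets as
    // attributes. Only compile-time constant inputs and unit steps are handled;
    // when no axes are given, the slice applies to every dimension in order.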
    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return prog.add_instruction(op, args[0]);
    }

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&)
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return prog.add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return prog.add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return prog.add_literal(v);
    }

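    // Gemm computes alpha * op(A) * op(B) + beta * C, where op() is an optional
    // transpose controlled by transA/transB. The transposes swap the last two
    // dimensions, C is multibroadcast to the output shape when needed, and
    // alpha/beta are folded into the dot operator.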
    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = prog.add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return prog.add_instruction(op::dot{alpha, beta}, l1, l2, l3);
            }
        }

        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }
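
    // MatMul follows numpy-style semantics: 1-D inputs are temporarily promoted to
    // matrices (a leading 1 for the first input, a trailing 1 for the second), the
    // batch dimensions are multibroadcast to a common shape, and any promoted axis
    // is squeezed away from the result. For example, {2, 3, 4} x {4, 5} -> {2, 3, 5}.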

    template <class Op>
    instruction_ref
    parse_matmul(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = prog.add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = prog.add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = prog.add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = prog.add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res     = prog.add_instruction(Op{1, 0}, bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = prog.add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return prog.add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({H, W}, x)
        // variance = reduce_mean({H, W}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();

        auto mean            = prog.add_instruction(op::reduce_mean{{2, 3}}, x);
        auto mean_bcast      = prog.add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = prog.add_instruction(op::sqdiff{}, x, mean_bcast);
        auto variance        = prog.add_instruction(op::reduce_mean{{2, 3}}, l0);
        auto l1              = prog.add_instruction(op::sub{}, x, mean_bcast);
        auto epsilon_literal = prog.add_literal(epsilon);
        auto epsilon_bcast   = prog.add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = prog.add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = prog.add_instruction(op::add{}, variance_bcast, epsilon_bcast);
        auto l3              = prog.add_instruction(op::rsqrt{}, l2);
        auto l4              = prog.add_instruction(op::mul{}, l1, l3);
        auto scale_bcast     = prog.add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = prog.add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = prog.add_instruction(op::mul{}, l4, scale_bcast);
        return prog.add_instruction(op::add{}, l5, bias_bcast);
    }

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::leaky_relu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        op::elu op{alpha};
        return prog.add_instruction(op, args.front());
    }

    instruction_ref parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return prog.add_instruction(op, args.front());
    }
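
    // ImageScaler multiplies the whole input by a scalar "scale" and then adds a
    // per-channel "bias" broadcast along axis 1, i.e.
    // y[n, c, h, w] = scale * x[n, c, h, w] + bias[c].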

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = prog.add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = prog.add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }
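
    // Pad: the "pads" values follow the ONNX layout
    // [x1_begin, x2_begin, ..., x1_end, x2_end, ...]. All-zero pads collapse to an
    // identity; "reflect" mode is lowered separately and any other non-"constant"
    // mode is rejected.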

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // if no padding is actually requested (all values are zero), this is an identity
        if(std::all_of(pads.begin(), pads.end(), [](const auto& i) { return i == 0; }))
        {
            return prog.add_instruction(migraphx::op::identity{}, args.front());
        }

        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode == "reflect")
                return reflect_pad(pads, args.front());
            if(mode != "constant")
            {
                MIGRAPHX_THROW(
                    "PARSE_PAD: migraphx currently only supports constant and reflect padding");
            }
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
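
    // Shape returns the dimensions of its input as a 1-D int64 tensor;
    // e.g. an input of shape {2, 3, 4} produces the literal [2, 3, 4].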

    // Use a literal instruction to replace the shape since the output of
    // the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return prog.add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return prog.add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }
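
    // ConstantOfShape takes the target dimensions from its (constant) input tensor and
    // fills the result with the single element of the "value" attribute (0.0f when the
    // attribute is absent); an empty input produces a scalar-like literal.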

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // input is empty, output is a scalar
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return prog.add_literal(l_out);
        }
    }

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return prog.add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
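
    // The RNN/GRU/LSTM parsers all follow the same pattern: validate hidden_size against
    // the weight shapes, map the "direction" and "activations" attributes onto migraphx
    // operators, pad the argument list with undefined placeholders, and return the full
    // hidden-state sequence plus the last hidden (and, for LSTM, last cell) state.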

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // the bidirectional case should have two activation functions:
        // one for the forward pass and one for the reverse pass.
        // if only one actv function is provided, we use it in both
        // forward and reverse directions
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators to have 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states = prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip},
                                                  std::move(args));

        // second output for the last hidden state
        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // need 4 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // Four activation functions are used in the bidirectional case. The ONNX
            // spec does not define how a shorter list is extended, so we use the
            // following scheme: with 1 function, repeat it four times; with 2, assume
            // forward and reverse share the same pair; with 3, repeat the 3rd once and
            // use it for the reverse direction. This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv_func_names)
    {
        // the bidirectional case needs 6 activation functions
        if(dirct == op::rnn_direction::bidirectional)
        {
            // Six activation functions are needed in the bidirectional case. The ONNX
            // spec does not say how to extend a shorter list, so we use the following
            // scheme: with 1 function, repeat it six times; with 2, repeat the 2nd once
            // and then repeat the resulting three for the reverse direction; with 3,
            // repeat all three for the reverse direction; with 4 or 5, the last provided
            // function is repeated to fill the remaining slots. This may need to change later.
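            // For example, {"sigmoid"} -> {sigmoid x6}; {"sigmoid", "tanh"} ->
            // {sigmoid, tanh, tanh, sigmoid, tanh, tanh}; {"sigmoid", "tanh", "relu"} ->
            // {sigmoid, tanh, relu, sigmoid, tanh, relu}.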
            switch(actv_func_names.size())
            {
            case 1:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(1),
                                   actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2)};
                break;

            case 4:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(3),
                                   actv_func_names.at(3),
                                   actv_func_names.at(3)};
                break;

            case 5:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(3),
                                   actv_func_names.at(4),
                                   actv_func_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(actv_func_names.size())
            {
            case 1:
                actv_func_names = {
                    actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                actv_func_names = {
                    actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)};
                break;

            default: break;
            }
        }
    }

    std::vector<instruction_ref>
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }
        else if(direction == "forward")
        {
            dirct = op::rnn_direction::forward;
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        lstm_actv_functions(dirct, vec_names);

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int input_forget = 0;
        if(contains(info.attributes, "input_forget"))
        {
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
        }

        // append undefined operators to make 8 arguments
        if(args.size() < 8)
        {
            auto ins = prog.add_instruction(op::undefined{});
            args.insert(args.end(), 8 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = prog.add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        auto last_output = prog.add_instruction(op::rnn_last_hs_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = prog.add_instruction(op::rnn_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }
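
    // The Reduce* family defaults to reducing over every axis; when the ONNX "keepdims"
    // attribute is 0, the reduced axes are squeezed out of the result instead of being
    // kept as size-1 dimensions.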

    template <class T>
    instruction_ref
    parse_reduce_oper(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return prog.add_instruction(T{axes}, std::move(args));
        }
        else
        {
            auto ins = prog.add_instruction(T{axes}, std::move(args));
            return prog.add_instruction(op::squeeze{axes}, ins);
        }
    }

    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto abs_ins = prog.add_instruction(op::abs{}, args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {abs_ins});
    }

    instruction_ref
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        auto sum_ins    = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
        return prog.add_instruction(op::sqrt{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), std::move(args));
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_log_sum_exp(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto exp_ins = prog.add_instruction(op::exp{}, args[0]);
        auto sum_ins = parse_reduce_oper<op::reduce_sum>({}, std::move(info), {exp_ins});
        return prog.add_instruction(op::log{}, sum_ins);
    }

    instruction_ref
    parse_reduce_sum_square(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        auto square_ins = prog.add_instruction(op::mul{}, args[0], args[0]);
        return parse_reduce_oper<op::reduce_sum>({}, std::move(info), {square_ins});
    }

    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return prog.add_instruction(op::convert{type}, std::move(args));
    }
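
    // Split slices the input along "axis" into info.num_outputs pieces; explicit "split"
    // sizes must sum to the axis length, otherwise the axis is divided evenly. For example,
    // a {4, 6} input split on axis 1 with split = {2, 4} yields slices of shape {4, 2} and {4, 4}.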
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                prog.add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }
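
    // OneHot is lowered as a gather against a depth x depth identity matrix, followed by a
    // transpose that moves the new one-hot axis into position and a blend of the off/on
    // values taken from the third input.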

    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }

        std::vector<float> depth_input(depth * depth, 0.0f);
        for(size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = 1.0f;
        }

        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = prog.add_literal({s, depth_input});
        auto gather_out = prog.add_instruction(op::gather{0}, {l_val, args[0]});

        // Finally, we need a transpose to move the innermost dim to the axis dim
        int n_rank = gather_out->get_shape().lens().size();
        if(axis < -n_rank or axis >= n_rank)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);
        auto tr_out = prog.add_instruction(op::transpose{perm}, gather_out);
        auto lens   = tr_out->get_shape().lens();

        auto off_val       = prog.add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
        auto on_val        = prog.add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
        auto diff          = prog.add_instruction(op::sub{}, on_val, off_val);
        auto unsq_off_val  = prog.add_instruction(op::multibroadcast{lens}, off_val);
        auto unsq_diff_val = prog.add_instruction(op::multibroadcast{lens}, diff);
        auto l_mul         = prog.add_instruction(op::mul{}, tr_out, unsq_diff_val);
        return prog.add_instruction(op::add{}, l_mul, unsq_off_val);
    }
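
    // Tile repeats the tensor along each axis by concatenating it with itself; repeats[i]
    // copies are laid end to end along axis i, e.g. repeats {2, 3} on a {4, 5} input gives {8, 15}.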

    instruction_ref
    parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        for(int i = 0; i < repeats.size(); i++)
        {
            auto l1 = l0;
            for(int j = 1; j < repeats[i]; j++)
            {
                l0 = prog.add_instruction(op::concat{i}, l0, l1);
            }
        }
        return l0;
    }
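
    // Range materializes the arithmetic sequence start, start + delta, ... as a literal
    // with ceil((limit - start) / delta) elements; all three inputs must be constant scalars.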

    instruction_ref
    parse_range(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {

        auto start_arg = args[0]->eval();
        check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported");
        auto limit_arg = args[1]->eval();
        check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported");
        auto delta_arg = args[2]->eval();
        check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported");

        assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and
               args[2]->get_shape().elements() == 1);

        instruction_ref l0;

        visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) {
            auto start_val = start.front();
            auto limit_val = limit.front();
            auto delta_val = delta.front();

            size_t num_elements = static_cast<size_t>(
                ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)));

            assert(num_elements > 0);

            using type = decltype(start_val);

            std::vector<type> range_vals(num_elements);

            std::generate(range_vals.begin(), range_vals.end(), [&]() {
                auto result = start_val;
                start_val += delta_val;
                return result;
            });

            l0 = prog.add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
        });
        return l0;
    }
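
    // ATen nodes are PyTorch escape hatches; only the embedding_bag overloads are recognized,
    // and they are lowered to a gather followed by the requested reduction (sum, mean, or max,
    // mirrored by reduce_mode_t below) over axis 0.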

    enum class reduce_mode_t
    {
        sum  = 0,
        mean = 1,
        max  = 2
    };

    instruction_ref parse_embedding_bag(const node_info& info, std::vector<instruction_ref> args)
    {
        if(args[2]->get_shape().elements() != 1)
            MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
        reduce_mode_t reduce_mode = reduce_mode_t::sum;
        if(contains(info.attributes, "mode"))
        {
            reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
        }

        auto l0 = prog.add_instruction(op::gather{}, args[0], args[1]);
        switch(reduce_mode)
        {
        case reduce_mode_t::sum: l0 = prog.add_instruction(op::reduce_sum{{0}}, l0); break;
        case reduce_mode_t::mean: l0 = prog.add_instruction(op::reduce_mean{{0}}, l0); break;
        case reduce_mode_t::max: l0 = prog.add_instruction(op::reduce_max{{0}}, l0); break;
        }
        return l0;
    }

    instruction_ref
    parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args)
    {
        if(contains(info.attributes, "operator"))
        {
            auto op_name = info.attributes.at("operator").s();
            if(op_name.find("embedding_bag") != std::string::npos)
            {
                return parse_embedding_bag(info, std::move(args));
            }
        }
        MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
    }

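    // Deserialize a ModelProto from a stream and build the program from its main graph.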
    void parse_from(std::istream& is)
    {
        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

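    // Same as above, but deserialize the ModelProto from an in-memory buffer.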
    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

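    // Translate a GraphProto into the program:
    //   1. initializers become literals,
    //   2. the remaining graph inputs become parameters (user-supplied or default dims),
    //   3. nodes are converted in file order through the registered operator parsers,
    //   4. the graph outputs become the operands of the final return instruction.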
    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
            instructions[f.name()] = prog.add_literal(parse_tensor(f));

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                std::vector<std::size_t> dims;
                if(map_input_dims.count(name) > 0)
                {
                    dims = map_input_dims.at(name);
                }

                shape s            = parse_type(input.type(), dims);
                instructions[name] = prog.add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                if(skip_unknown_operators)
                    result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
                else
                    MIGRAPHX_THROW("Unknown operator: " + node.op_type());
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        prog.add_return(output_ins);
    }

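    // An empty input name marks an optional input that was not supplied; map it to an
    // undefined placeholder instruction so later lookups succeed.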
    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }

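    // Attribute helpers: get_attributes indexes a node's attributes by name,
    // from_repeated turns a repeated protobuf field into a 1-D literal, and parse_value
    // converts a single AttributeProto into a literal (empty for unsupported kinds).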
    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

    static literal parse_value(const onnx::AttributeProto& attr)
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("Invalid attribute type");
    }

    // parse_tensor: convert a TensorProto initializer into a migraphx literal, reading
    // either the raw_data blob or the typed repeated data fields for the element type.
    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            switch(t.data_type())
            {
            case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
            case onnx::TensorProto::FLOAT16:
                return create_literal(shape::half_type, dims, s.data());
            case onnx::TensorProto::DOUBLE:
                return create_literal(shape::double_type, dims, s.data());
            case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
            case onnx::TensorProto::INT8:
            case onnx::TensorProto::UINT16:
            case onnx::TensorProto::INT16:
            case onnx::TensorProto::INT32:
            case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
            case onnx::TensorProto::UINT8:
            case onnx::TensorProto::STRING:
            case onnx::TensorProto::UNDEFINED:
            case onnx::TensorProto::UINT32:
            case onnx::TensorProto::UINT64:
            case onnx::TensorProto::COMPLEX64:
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
        case onnx::TensorProto::INT8:
        case onnx::TensorProto::UINT16:
        case onnx::TensorProto::INT16:
        case onnx::TensorProto::INT32:
        case onnx::TensorProto::BOOL:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::UINT8:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::UINT32:
        case onnx::TensorProto::UINT64:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    // create_literal: build a literal of the given type and dims from a raw byte
    // pointer (used for raw_data payloads); the overload below copies from a typed
    // container instead.
    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // scalar constants in the onnx file carry no dims; build a scalar shape for them
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

    // parse_type: convert an onnx TypeProto into a migraphx shape. Dimensions supplied
    // through map_input_dims take precedence; otherwise the tensor's own dims are used,
    // with unknown or non-positive dimensions replaced by default_dim_value.
    shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims)
    {
        shape::type_t shape_type{};
        switch(t.tensor_type().elem_type())
        {
        case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
        case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
        case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
        case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
        case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
        case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
        case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
        case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
        case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
        case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
        case onnx::TensorProto::UINT8: shape_type = shape::uint8_type; break;
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::BOOL:
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            break; // throw std::runtime_error("Unsupported type");
        }

        if(!input_dims.empty())
        {
            return {shape_type, input_dims};
        }

        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                               {
                                   return default_dim_value;
                               }
                               return d.dim_value();
                           }
                           else
                           {
                               return default_dim_value;
                           }
                       });

        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }
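    // get_type: map a raw TensorProto data-type code onto the corresponding migraphx
    // shape type; unsupported codes are rejected with an error.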

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
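    // Constant-folding helpers call eval() on their inputs; an empty argument means the
    // value is not known at parse time, which is reported with the supplied message.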

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

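// Shared driver for the public entry points below: configure a parser from onnx_options,
// hand the remaining arguments to parse_from, and optionally dump the partially built
// program to stderr when parsing fails.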
template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
    onnx_parser parser;
    parser.map_input_dims         = options.map_input_dims;
    parser.default_dim_value      = options.default_dim_value;
    parser.skip_unknown_operators = options.skip_unknown_operators;

    if(options.print_program_on_error)
    {
        // Log the program when it can't be parsed
        try
        {
            parser.parse_from(std::forward<Ts>(xs)...);
        }
        catch(...)
        {
            std::cerr << parser.prog << std::endl;
            throw;
        }
    }
    else
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    return std::move(parser.prog);
}

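// Parse an ONNX model from a file. Typical usage (illustrative only):
//   migraphx::onnx_options options;
//   options.default_dim_value = 4; // value substituted for unknown dimensions
//   auto prog = migraphx::parse_onnx("model.onnx", options);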
program parse_onnx(const std::string& name, const onnx_options& options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input);
}

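// Parse an ONNX model that is already held in memory, either as a std::string or as a
// raw pointer/size pair.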
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
    return parse_onnx_from(options, data, size);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx