#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/type_traits.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>

#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_cell_output.hpp>
#include <migraphx/op/rnn_last_hs_output.hpp>
#include <migraphx/op/rnn_variable_seq_lens.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/undefined.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/op/unsqueeze.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace onnx = onnx_for_migraphx;

struct onnx_parser
{
    std::string filename;
    std::string path    = ".";
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    struct node_info
    {
        attribute_map attributes{};
        std::size_t num_outputs = 1;
    };
    using node_map = std::unordered_map<std::string, onnx::NodeProto>;
    using op_func =
        std::function<std::vector<instruction_ref>(node_info, std::vector<instruction_ref>)>;
    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog                  = program();
    module* mm                    = prog.get_main_module();
    bool is_pytorch               = false;
    std::size_t default_dim_value = 1;
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
    bool skip_unknown_operators = false;

    std::unordered_map<std::string, op_func> ops;
    std::unordered_map<std::string, operation> map_actv_funcs;

    onnx_parser()
    {
        // ONNX operators are registered alphabetically by name
        add_generic_op("Abs", "abs");
        add_generic_op("Acos", "acos");
        add_generic_op("Acosh", "acosh");
        add_generic_op("Asin", "asin");
        add_generic_op("Asinh", "asinh");
        add_generic_op("Atan", "atan");
        add_generic_op("Atanh", "atanh");
        add_generic_op("Ceil", "ceil");
        add_generic_op("Concat", "concat");
        add_generic_op("Cos", "cos");
        add_generic_op("Cosh", "cosh");
        add_generic_op("Erf", "erf");
        add_generic_op("Exp", "exp");
        add_generic_op("Flatten", "flatten");
        add_generic_op("Floor", "floor");
        add_generic_op("Gather", "gather", true);
        add_generic_op("Identity", "identity");
        add_generic_op("Log", "log");
        add_generic_op("LogSoftmax", "logsoftmax");
        add_generic_op("Neg", "neg");
        add_generic_op("Reciprocal", "recip");
        add_generic_op("Relu", "relu");
        add_generic_op("Round", "round");
        add_generic_op("Sigmoid", "sigmoid");
        add_generic_op("Sign", "sign");
        add_generic_op("Sin", "sin");
        add_generic_op("Sinh", "sinh");
        add_generic_op("Softmax", "softmax");
        add_generic_op("Sqrt", "sqrt");
        add_generic_op("Squeeze", "squeeze", true);
        add_generic_op("Tan", "tan");
        add_generic_op("Tanh", "tanh");
        add_generic_op("Unsqueeze", "unsqueeze", true);

        add_binary_op("Add", "add");
        add_binary_op("Div", "div");
        add_binary_op("Mul", "mul");
        add_binary_op("Pow", "pow");
        add_binary_op("PRelu", "prelu");
        add_binary_op("Sub", "sub");

        add_variadic_op("Sum", "add");
        add_variadic_op("Max", "max");
        add_variadic_op("Min", "min");

        add_mem_op("ATen", &onnx_parser::parse_aten);
        add_mem_op("AveragePool", &onnx_parser::parse_pooling);
        add_mem_op("ArgMax", "argmax", &onnx_parser::parse_arg_op);
        add_mem_op("ArgMin", "argmin", &onnx_parser::parse_arg_op);
        add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
        add_mem_op("Cast", &onnx_parser::parse_cast);
        add_mem_op("Clip", &onnx_parser::parse_clip);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("ConstantOfShape", &onnx_parser::parse_constant_of_shape);
        add_mem_op("Conv", "convolution", &onnx_parser::parse_conv);
        add_mem_op("ConvInteger", "quant_convolution", &onnx_parser::parse_conv);
        add_mem_op("ConvTranspose", &onnx_parser::parse_conv_transpose);
        add_mem_op("Dropout", &onnx_parser::parse_dropout);
        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Equal", "equal", &onnx_parser::parse_compare_op);
        add_mem_op("Expand", &onnx_parser::parse_expand);
        add_mem_op("GatherElements", &onnx_parser::parse_gather_elements);
        add_mem_op("Gemm", &onnx_parser::parse_gemm);
        add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
        add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
        add_mem_op("Greater", "greater", &onnx_parser::parse_compare_op);
        add_mem_op("GRU", &onnx_parser::parse_gru);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("InstanceNormalization", &onnx_parser::parse_instancenorm);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
        add_mem_op("Less", "less", &onnx_parser::parse_compare_op);
        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("LSTM", &onnx_parser::parse_lstm);
        add_mem_op("MatMul", "dot", &onnx_parser::parse_matmul);
        add_mem_op("MatMulInteger", "quant_dot", &onnx_parser::parse_matmul);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
        add_mem_op("NonZero", &onnx_parser::parse_nonzero);
        add_mem_op("OneHot", &onnx_parser::parse_onehot);
        add_mem_op("Pad", &onnx_parser::parse_pad);
        add_mem_op("Range", &onnx_parser::parse_range);
        add_mem_op("ReduceL1", &onnx_parser::parse_reduce_l1);
        add_mem_op("ReduceL2", &onnx_parser::parse_reduce_l2);
        add_mem_op("ReduceLogSum", &onnx_parser::parse_reduce_log_sum);
        add_mem_op("ReduceLogSumExp", &onnx_parser::parse_reduce_log_sum_exp);
        add_mem_op("ReduceMax", "reduce_max", &onnx_parser::parse_reduce_oper);
        add_mem_op("ReduceMean", "reduce_mean", &onnx_parser::parse_reduce_oper);
        add_mem_op("ReduceMin", "reduce_min", &onnx_parser::parse_reduce_oper);
        add_mem_op("ReduceProd", "reduce_prod", &onnx_parser::parse_reduce_oper);
        add_mem_op("ReduceSum", "reduce_sum", &onnx_parser::parse_reduce_oper);
        add_mem_op("ReduceSumSquare", &onnx_parser::parse_reduce_sum_square);
        add_mem_op("Reshape", &onnx_parser::parse_reshape);
        add_mem_op("Resize", &onnx_parser::parse_resize);
        add_mem_op("RNN", &onnx_parser::parse_rnn);
        add_mem_op("Selu", &onnx_parser::parse_selu);
        add_mem_op("Shape", &onnx_parser::parse_shape);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Split", &onnx_parser::parse_split);
        add_mem_op("Tile", &onnx_parser::parse_tile);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
        add_mem_op("Upsample", &onnx_parser::parse_upsample);
        add_mem_op("Where", &onnx_parser::parse_where);

        // init the activation function map
        init_actv_func();
    }

    void init_actv_func()
    {
        // Activation names may be all lower case or have the first letter
        // capitalized
        map_actv_funcs.insert(std::make_pair("tanh", make_op("tanh")));
        map_actv_funcs.insert(std::make_pair("relu", make_op("relu")));
        map_actv_funcs.insert(std::make_pair("sigmoid", make_op("sigmoid")));
        map_actv_funcs.insert(std::make_pair("leakyrelu", make_op("leaky_relu")));
        map_actv_funcs.insert(std::make_pair("elu", make_op("elu")));
    }

    operation load(const std::string& name, const node_info& info) const
    {
        auto op = make_op(name);
        auto v  = op.to_value();
        for(auto&& x : v)
        {
            if(info.attributes.count(x.get_key()) == 0)
                continue;
            literal s = parse_value(info.attributes.at(x.get_key()));
            if(x.is_array())
            {
                std::vector<value> values;
                s.visit([&](auto y) {
                    std::transform(y.begin(), y.end(), std::back_inserter(values), [](auto z) {
                        return value(z);
                    });
                });
                x = values;
            }
            else
            {
                s.visit([&](auto y) { x = y.front(); });
            }
        }
        op.from_value(v);
        return op;
    }
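
    // A minimal illustration (not part of the parsing flow): for a node whose
    // ONNX attribute axis = 2, load("softmax", info) yields the same operator
    // as make_op("softmax", {{"axis", 2}}); attributes that do not match a
    // field in the operator's value are simply ignored.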

    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(name, [=](auto&&... xs) {
            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
        });
    }

    // Multi output op
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }

    template <class F>
    void add_mem_op(const std::string& name, F f)
    {
        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class F>
    void add_mem_op(const std::string& onnx_name, const std::string& op_name, F f)
    {
        add_op(onnx_name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, onnx_name, op_name, std::forward<decltype(xs)>(xs)...);
        });
    }

    void add_binary_op(const std::string& onnx_name, const std::string& op_name)
    {
        add_op(onnx_name, [this, op_name](node_info info, std::vector<instruction_ref> args) {
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
            if(contains(info.attributes, "broadcast") and contains(info.attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(info.attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
                    uint64_t axis = parse_value(info.attributes.at("axis")).at<uint64_t>();
                    auto l = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()},
                                                 args[1]);
                    return mm->add_instruction(make_op(op_name), args[0], l);
                }
                return mm->add_instruction(make_op(op_name), args);
            }
            else
            {
                return add_broadcastable_binary_op(args[0], args[1], op_name);
            }
        });
    }

    std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                      std::vector<std::size_t> s1)
    {
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)
        if(s0.size() > s1.size())
        {
            s0.swap(s1);
        }

        std::vector<std::size_t> out_lens(s1);
        auto offset = s1.size() - s0.size();
        std::transform(s0.begin(),
                       s0.end(),
                       s1.begin() + offset,
                       out_lens.begin() + offset,
                       [&](auto a, auto b) {
                           if(a != b and a != 1 and b != 1)
                           {
                               MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
                                              to_string_range(s0) + "} and {" +
                                              to_string_range(s1) + "} mismatch!");
                           }
                           return std::max(a, b);
                       });

        return out_lens;
    }

    instruction_ref make_contiguous(instruction_ref ins) const
    {
        if(ins->get_shape().standard())
        {
            return ins;
        }

        return mm->add_instruction(make_op("contiguous"), ins);
    }

    instruction_ref
    add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, const std::string& name)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Get lengths for both arguments
            auto s0       = arg0->get_shape().lens();
            auto s1       = arg1->get_shape().lens();
            auto out_lens = compute_broadcasted_lens(s0, s1);

            auto l0 = arg0;
            if(arg0->get_shape().lens() != out_lens)
                l0 = mm->add_instruction(op::multibroadcast{out_lens}, arg0);

            auto l1 = arg1;
            if(arg1->get_shape().lens() != out_lens)
                l1 = mm->add_instruction(op::multibroadcast{out_lens}, arg1);

            return mm->add_instruction(make_op(name), l0, l1);
        }
        else
        {
            return mm->add_instruction(make_op(name), {arg0, arg1});
        }
    }

    void add_generic_op(const std::string& onnx_name,
                        const std::string& op_name,
                        bool contiguous = false)
    {
        add_op(
            onnx_name,
            [this, op_name, contiguous](const node_info& info, std::vector<instruction_ref> args) {
                auto op = load(op_name, info);
                if(contiguous)
                {
                    std::transform(args.begin(), args.end(), args.begin(), [&](auto arg) {
                        return this->make_contiguous(arg);
                    });
                }
                return mm->add_instruction(op, args);
            });
    }

    void add_variadic_op(const std::string& onnx_name, const std::string& op_name)
    {
        add_op(onnx_name, [this, op_name](const node_info&, std::vector<instruction_ref> args) {
            return std::accumulate(std::next(args.begin()),
                                   args.end(),
                                   args.front(),
                                   [this, op_name](instruction_ref a, instruction_ref b) {
                                       return add_broadcastable_binary_op(a, b, op_name);
                                   });
        });
    }
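
    // For example, an ONNX Sum node with inputs (a, b, c) folds from the left
    // into add(add(a, b), c), broadcasting shapes at each step.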

    template <class T>
    std::vector<int64_t> to_int64_vector(const std::vector<T>& input_vector)
    {
        std::vector<int64_t> output_vector(input_vector.begin(), input_vector.end());
        return output_vector;
    }

    instruction_ref add_bias(const std::vector<instruction_ref>& args,
                             instruction_ref curr_ins,
                             uint64_t axis) const
    {
        if(args.size() == 3)
        {
            auto bias_bcast =
                mm->add_instruction(op::broadcast{axis, curr_ins->get_shape().lens()}, args[2]);
            return mm->add_instruction(make_op("add"), curr_ins, bias_bcast);
        }
        return curr_ins;
    }

    static bool is_asym_padding(const std::vector<int64_t>& padding)
    {
        assert(padding.size() % 2 == 0);
        size_t pad_ndims = padding.size() / 2;

        for(size_t i = 0; i < pad_ndims; i++)
        {
            if(padding[i] != padding[i + pad_ndims])
            {
                return true;
            }
        }
        return false;
    }
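
    // Example: ONNX pads are laid out {x1_begin, x2_begin, ..., x1_end,
    // x2_end, ...}, so padding = {0, 1, 0, 2} pairs (0, 0) and (1, 2);
    // 1 != 2, hence the padding is asymmetric and this returns true.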

    void check_asym_padding(instruction_ref& ins,
                            const std::vector<int64_t>& padding,
                            value& v,
                            int count_include_pad = 0,
                            float pad_val         = 0) const
    {
        size_t pad_ndims  = padding.size() / 2;
        auto left_pad_it  = padding.begin();
        auto right_pad_it = left_pad_it + pad_ndims;

        if(is_asym_padding(padding) or count_include_pad == 1)
        {
            std::vector<int64_t> asym_pads{0, 0, 0, 0}; // don't pad N and C
            // add left pads
            asym_pads.insert(asym_pads.begin() + 2, left_pad_it, right_pad_it);
            // add right pads
            asym_pads.insert(asym_pads.begin() + pad_ndims + 4, right_pad_it, padding.end());
            ins = mm->add_instruction(op::pad{asym_pads, pad_val}, ins);
        }
        else
        {
            v["padding"] = std::vector<size_t>(left_pad_it, right_pad_it);
        }
    }
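
    // Example (4-D NCHW input, kdims = 2): padding = {0, 1, 0, 2} turns into
    // an explicit pad of {0, 0, 0, 1, 0, 0, 0, 2}, leaving N and C alone and
    // padding only the spatial dims.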

    instruction_ref
    parse_clip(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3 and args[2]->name() != "undefined")
        {
            max_arg  = args[2];
            max_used = true;
        }

        if(args.size() >= 2 and args[1]->name() != "undefined")
        {
            min_arg  = args[1];
            min_used = true;
        }
        // earlier opsets (< 11) pass min/max as attributes instead of inputs
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {
            float min_val = parse_value(info.attributes.at("min")).at<float>();
            float max_val = parse_value(info.attributes.at("max")).at<float>();
            min_arg       = mm->add_literal(min_val);
            max_arg       = mm->add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
        {
            min_arg = mm->add_instruction(op::multibroadcast{input_lens}, min_arg);
        }

        if(max_used)
        {
            max_arg = mm->add_instruction(op::multibroadcast{input_lens}, max_arg);
        }

        if(min_used and max_used)
        {
            return mm->add_instruction(make_op("clip"), args[0], min_arg, max_arg);
        }
        else if(max_used)
        {
            return mm->add_instruction(make_op("min"), args[0], max_arg);
        }
        else if(min_used)
        {
            return mm->add_instruction(make_op("max"), args[0], min_arg);
        }
        else
        {
            return mm->add_instruction(make_op("identity"), args[0]);
        }
    }

    instruction_ref parse_arg_op(const std::string&,
                                 const std::string& op_name,
                                 node_info info,
                                 std::vector<instruction_ref> args) const
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = static_cast<int64_t>(parse_value(info.attributes.at("axis")).at<int>());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 0)
        {
            auto ins = mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
            return mm->add_instruction(op::squeeze{{axis}}, ins);
        }
        else
        {
            return mm->add_instruction(make_op(op_name, {{"axis", axis}}), std::move(args));
        }
    }

    void calc_reflect_indices(std::vector<int>& indices, const int64_t num_dims)
    {
        int k         = 0;
        bool reversed = false;
        // in reflect padding, if num_pads > num_dims, compute the extra pad
        // indices periodically, e.g. (1, 2, 3, 2, 1, 0)
        for(int& idx : indices)
        {
            if(k == num_dims - 1)
                reversed = true;
            if(k == 0)
                reversed = false;
            if(reversed)
                k--;
            else
                k++;
            idx = k;
        }
    }
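
    // Worked example: with num_dims = 4, an indices vector of size 6 is
    // filled with (1, 2, 3, 2, 1, 0), walking up to the last valid index and
    // then back down, which reflect padding needs when the pad count exceeds
    // the dimension size.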

    instruction_ref reflect_pad(const std::vector<int64_t>& pads, instruction_ref input)
    {
        size_t num_dims = pads.size() / 2;
        std::vector<int> ldims(pads.begin(), pads.begin() + num_dims);
        std::vector<int> rdims(pads.begin() + num_dims, pads.end());
        assert(ldims.size() == rdims.size());

        std::vector<int64_t> axes(num_dims);
        std::iota(axes.begin(), axes.end(), int64_t{0});

        // iterate over dimensions, starting from lowest dimension
        for(int64_t i = num_dims - 1; i >= 0; i--)
        {
            auto axis   = i;
            auto lcount = ldims.at(i);
            auto rcount = rdims.at(i);
            if(lcount == 0 and rcount == 0) // no padding for current dim
                continue;

            // calculate starts and ends for each iteration since shape may change
            std::vector<size_t> dims = input->get_shape().lens();
            std::vector<int64_t> starts(axes.size(), 0);
            std::vector<int64_t> ends(dims.begin(), dims.end());
            std::vector<instruction_ref> slices;

            auto starts_it = starts.begin() + i;
            auto ends_it   = ends.begin() + i;
            auto dims_it   = dims.begin() + i;

            std::vector<int> l_indices(lcount);
            std::vector<int> r_indices(rcount);

            // compute slice indices in a periodic fashion
            calc_reflect_indices(l_indices, *dims_it);
            calc_reflect_indices(r_indices, *dims_it);

            for(int idx : l_indices)
            {
                *starts_it = idx;
                *ends_it   = *starts_it + 1;
                slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
            }
            // when padding on the left side, the outermost pad should be at the beginning
            std::reverse(slices.begin(), slices.end());
            slices.push_back(input);
            for(int idx : r_indices)
            {
                *starts_it = *dims_it - idx - 1;
                *ends_it   = *starts_it + 1;
                slices.push_back(mm->add_instruction(op::slice{axes, starts, ends}, input));
            }
            input = mm->add_instruction(op::concat{axis}, slices);
        }
        return input;
    }
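
    // Worked example: a 1-D input (0, 1, 2, 3) with pads = {2, 1} produces
    // slices (2), (1) on the left and (2) on the right, concatenated to
    // (2, 1, 0, 1, 2, 3, 2), matching ONNX "reflect" mode.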

    void check_attr_sizes(size_t kdims, size_t attr_size, const std::string& error_msg)
    {
        if(kdims != attr_size)
        {
            MIGRAPHX_THROW(error_msg + " k-dims: " + to_string(kdims) +
                           " attribute size: " + to_string(attr_size));
        }
    }

    void recalc_conv_attributes(value& v, size_t kdims)
    {
        if(v["padding"].size() != kdims)
        {
            v["padding"].resize(kdims);
            std::fill_n(v["padding"].begin(), kdims, 0);
        }
        if(v["stride"].size() != kdims)
        {
            v["stride"].resize(kdims);
            std::fill_n(v["stride"].begin(), kdims, 1);
        }
        if(v["dilation"].size() != kdims)
        {
            v["dilation"].resize(kdims);
            std::fill_n(v["dilation"].begin(), kdims, 1);
        }
    }

    static void cal_auto_padding_size(node_info info,
                                      value& v,
                                      const std::vector<std::size_t>& k_lens,
                                      const std::vector<std::size_t>& dilation,
                                      const std::vector<std::size_t>& in_lens,
                                      std::vector<int64_t>& paddings)
    {
        size_t kdims = in_lens.size() - 2;
        assert(k_lens.size() == kdims and dilation.size() == kdims);

        if(!contains(info.attributes, "auto_pad"))
        {
            return;
        }

        auto auto_pad = info.attributes["auto_pad"].s();
        if(auto_pad.find("SAME") != std::string::npos)
        {
            bool is_same_upper = (auto_pad.find("SAME_UPPER") != std::string::npos);
            paddings.resize(2 * kdims);

            for(size_t i = 0; i < paddings.size() / 2; i++)
            {
                calculate_padding(i,
                                  paddings,
                                  in_lens[i + 2],
                                  v["stride"][i].to<int64_t>(),
                                  dilation[i],
                                  k_lens[i],
                                  is_same_upper);
            }
        }
    }
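
    // A sketch of the SAME-padding arithmetic this helper relies on (assuming
    // calculate_padding from pad_calc.hpp follows the ONNX rule; illustrative
    // only): per spatial dim,
    //   out       = ceil(in / stride)
    //   pad_total = (out - 1) * stride + (k - 1) * dilation + 1 - in
    // SAME_UPPER puts the larger half of pad_total at the end of the dim.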

    static void check_padding_mode(node_info info, const std::string& op_name)
    {
        // ensure pads are used only when auto_pad is "NOTSET"
        if(contains(info.attributes, "pads") and contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("PARSE_" + op_name +
                               ": auto_pad and padding cannot be specified simultaneously");
            }
        }
    }

    instruction_ref parse_conv(const std::string&,
                               const std::string& op_name,
                               node_info info,
                               std::vector<instruction_ref> args)
    {
        auto op      = make_op(op_name);
        auto values  = op.to_value();
        auto l0      = args[0];
        auto weights = args[1];
        auto in_lens = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        // ensure pads are used only when auto_pad is "NOTSET"
        check_padding_mode(info, "CONV");

        if(contains(info.attributes, "strides"))
        {
            values["stride"].clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
            check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
        }
        if(contains(info.attributes, "dilations"))
        {
            values["dilation"].clear();
            copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
            check_attr_sizes(
                kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
        }

        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            values["padding"].clear();
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
        }

        if(contains(info.attributes, "auto_pad"))
        {
            auto weight_lens = weights->get_shape().lens();
            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
            cal_auto_padding_size(info,
                                  values,
                                  k_lens,
                                  values["dilation"].to_vector<std::size_t>(),
                                  in_lens,
                                  padding);
            auto auto_pad = info.attributes["auto_pad"].s();
            if(auto_pad.find("SAME") != std::string::npos)
            {
                values["padding_mode"] = to_value(op::padding_mode_t::same);
            }
        }
        check_asym_padding(l0, padding, values);

        if(contains(info.attributes, "group"))
        {
            values["group"] = parse_value(info.attributes.at("group")).at<int>();
        }

        recalc_conv_attributes(values, kdims);

        op.from_value(values);
        auto l1 = mm->add_instruction(op, l0, args[1]);
        return add_bias(args, l1, 1);
    }

    instruction_ref
    parse_conv_transpose(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        operation op = make_op("deconvolution");
        value values = op.to_value();
        auto l0 = args[0];
        std::vector<std::int64_t> padding;
        bool asym_padding = false;
        auto in_lens      = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        // ensure pads are used only when auto_pad is "NOTSET"
        check_padding_mode(info, "CONV_TRANSPOSE");

        if(contains(info.attributes, "pads"))
        {
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));

            asym_padding = is_asym_padding(padding);

            if(not asym_padding)
            {
                size_t pad_ndims = padding.size() / 2;
                check_attr_sizes(kdims, pad_ndims, "PARSE_CONV_TRANSPOSE: inconsistent paddings");
                values["padding"].clear();
                std::transform(padding.begin(),
                               padding.begin() + pad_ndims,
                               std::back_inserter(values["padding"]),
                               [](auto pad_val) { return pad_val; });
            }
        }
        if(contains(info.attributes, "strides"))
        {
            values["stride"].clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
            check_attr_sizes(
                kdims, values["stride"].size(), "PARSE_CONV_TRANSPOSE: inconsistent strides");
        }
        if(contains(info.attributes, "dilations"))
        {
            values["dilation"].clear();
            copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
            check_attr_sizes(
                kdims, values["dilation"].size(), "PARSE_CONV_TRANSPOSE: inconsistent dilations");
        }
        if(contains(info.attributes, "auto_pad"))
        {
            auto s = info.attributes["auto_pad"].s();
            if(contains(info.attributes, "pads") and to_upper(s) != "NOTSET")
            {
                MIGRAPHX_THROW("PARSE_CONV_TRANSPOSE: auto_pad and padding cannot be specified "
                               "simultaneously");
            }

            if(s.find("SAME") != std::string::npos)
            {
                values["padding_mode"] = to_value(op::padding_mode_t::same);
            }
        }

        if(contains(info.attributes, "group"))
        {
            values["group"] = parse_value(info.attributes.at("group")).at<int>();
        }

        recalc_conv_attributes(values, kdims);

        op.from_value(values);
        auto l1                   = mm->add_instruction(op, l0, args[1]);
        std::vector<int64_t> dims = to_int64_vector(l1->get_shape().lens());
        std::vector<int64_t> curr_shape(dims.begin() + 2, dims.end());
        if(asym_padding)
        {
            std::vector<int64_t> axes(kdims);
            std::iota(axes.begin(), axes.end(), 2); // ignore first 2 dims

            auto pad_kdim_start = padding.begin() + kdims;
            std::vector<int64_t> starts(padding.begin(), pad_kdim_start);

            std::vector<int64_t> ends{};
            std::transform(curr_shape.begin(),
                           curr_shape.end(),
                           pad_kdim_start,
                           std::back_inserter(ends),
                           [](auto curr_dim, auto pad_dim) { return curr_dim - pad_dim; });

            l1 = mm->add_instruction(op::slice{axes, starts, ends}, l1);
        }

        if(contains(info.attributes, "output_padding"))
        {
            size_t non_kdims = dims.size() * 2 - kdims;
            std::vector<int64_t> output_padding(non_kdims, 0);
            copy(info.attributes["output_padding"].ints(), std::back_inserter(output_padding));
            check_attr_sizes(kdims,
                             output_padding.size() - non_kdims,
                             "PARSE_CONV_TRANSPOSE: inconsistent output padding");
            l1 = mm->add_instruction(op::pad{output_padding}, l1);
        }

        if(contains(info.attributes, "output_shape"))
        {
            std::vector<int64_t> output_shape;
            copy(info.attributes["output_shape"].ints(), std::back_inserter(output_shape));
            check_attr_sizes(
                kdims, output_shape.size(), "PARSE_CONV_TRANSPOSE: inconsistent output shape");
            dims = to_int64_vector(l1->get_shape().lens());
            copy(dims.begin() + 2, dims.end(), curr_shape.begin());
            if(curr_shape != output_shape)
            {
                std::vector<int64_t> target_padding(dims.size() * 2 - kdims, 0);
                std::transform(output_shape.begin(),
                               output_shape.end(),
                               curr_shape.begin(),
                               std::back_inserter(target_padding),
                               [](auto out_dim, auto curr_dim) { return out_dim - curr_dim; });
                l1 = mm->add_instruction(op::pad{target_padding}, l1);
            }
        }

        return add_bias(args, l1, 1);
    }

    static void
    tune_padding_to_symmetric(int64_t& left, int64_t& right, const int stride, int64_t& s_start)
    {
        s_start = 0;
        if(left > right)
        {
            right = left;
        }
        else if(left < right)
        {
            auto diff = right - left;
            s_start   = (diff + stride - 1) / stride;
            left      = left + s_start * stride;
            right     = left;
        }
    }
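
    // Example: left = 1, right = 3, stride = 2 gives diff = 2, so
    // s_start = ceil(2 / 2) = 1 and both pads become 1 + 1 * 2 = 3;
    // parse_pooling later slices the output starting at s_start to undo the
    // extra padding.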

    static void tune_padding_size(const value& v,
                                  std::vector<int64_t>& padding,
                                  int count_include_pad,
                                  std::vector<int64_t>& s_start)
    {
        // for max pooling, or when count_include_pad is 1, no change is required
        if(v.at("mode").to<std::string>() == "max" or count_include_pad == 1)
        {
            return;
        }

        // if padding is symmetric, return directly
        if(!is_asym_padding(padding))
        {
            return;
        }

        // asymmetric padding, make it symmetric
        std::size_t n_dims = padding.size() / 2;
        s_start.resize(n_dims);
        for(std::size_t i = 0; i < n_dims; ++i)
        {
            tune_padding_to_symmetric(
                padding[i], padding[i + n_dims], v.at("stride")[i].to<int64_t>(), s_start[i]);
        }
    }

    instruction_ref
    parse_pooling(const std::string& name, node_info info, std::vector<instruction_ref> args)
    {
        std::string mode = ends_with(name, "MaxPool") ? "max" : "average";
        operation op     = make_op("pooling", {{"mode", mode}});
        value values     = op.to_value();
        auto l0          = args[0];
        auto in_lens     = l0->get_shape().lens();
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        if(starts_with(name, "Global"))
        {
            values["lengths"] = std::vector<size_t>(in_lens.begin() + 2, in_lens.end());
        }

        // pass ceil_mode through to the pooling operator when present
        if(contains(info.attributes, "ceil_mode"))
        {
            values["ceil_mode"] = static_cast<bool>(info.attributes.at("ceil_mode").i());
        }

        // count include padding; if count_include_pad is 1, we always use
        // explicit pad
        int count_include_pad = 0;
        if(contains(info.attributes, "count_include_pad"))
        {
            count_include_pad = info.attributes.at("count_include_pad").i();
        }

        if(contains(info.attributes, "strides"))
        {
            values["stride"].clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
            check_attr_sizes(kdims, values["stride"].size(), "PARSE_POOLING: inconsistent strides");
        }
        if(contains(info.attributes, "kernel_shape"))
        {
            values["lengths"].clear();
            copy(info.attributes["kernel_shape"].ints(), std::back_inserter(values["lengths"]));
            check_attr_sizes(
                kdims, values["lengths"].size(), "PARSE_POOLING: inconsistent lengths");
        }

        // ensure pads are used only when auto_pad is "NOTSET"
        check_padding_mode(info, "POOLING");

        std::vector<int64_t> paddings;
        float pad_val = ((mode == "max") ? std::numeric_limits<float>::lowest() : 0.0f);
        if(contains(info.attributes, "pads"))
        {
            values["padding"].clear();
            copy(info.attributes["pads"].ints(), std::back_inserter(paddings));
            check_attr_sizes(
                kdims, paddings.size() / 2, "PARSE_POOLING: inconsistent explicit paddings");
        }

        if(contains(info.attributes, "auto_pad"))
        {
            values["padding"].clear();
            // the computed paddings may be empty; they are filled with 0
            // (no padding) below
            cal_auto_padding_size(info,
                                  values,
                                  values["lengths"].to_vector<std::size_t>(),
                                  {1, 1},
                                  in_lens,
                                  paddings);
        }

        if(paddings.size() != 2 * kdims)
        {
            paddings.resize(kdims * 2);
            std::fill_n(paddings.begin(), 2 * kdims, 0);
        }

        if(values["padding"].size() != kdims)
        {
            values["padding"].resize(kdims);
            std::fill_n(values["padding"].begin(), kdims, 0);
        }

        if(values["stride"].size() != kdims)
        {
            values["stride"].resize(kdims);
            std::fill_n(values["stride"].begin(), kdims, 1);
        }
        // used to calculate the supposed output shape
        std::vector<int64_t> orig_padding(paddings.begin(), paddings.end());

        std::vector<int64_t> slice_start;
        std::vector<int64_t> slice_end;
        tune_padding_size(values, paddings, count_include_pad, slice_start);

        if(!slice_start.empty())
        {
            // calculate expected output shape
            orig_padding.insert(orig_padding.begin() + kdims, 2, 0);
            orig_padding.insert(orig_padding.begin(), 2, 0);
            op::pad pad{orig_padding, 0.0f};
            shape padded_shape = pad.compute_shape({l0->get_shape()});
            auto out_lens      = make_op("pooling", values).compute_shape({padded_shape}).lens();

            // compute slice_end information
            slice_end.resize(slice_start.size());
            std::transform(out_lens.begin() + 2,
                           out_lens.end(),
                           slice_start.begin(),
                           slice_end.begin(),
                           [](auto i, auto j) { return i + j; });
        }

        check_asym_padding(l0, paddings, values, count_include_pad, pad_val);
        in_lens = l0->get_shape().lens();
        for(size_t i = 0; i < kdims; i++)
        {
            if(values["lengths"][i].to<int64_t>() >
               in_lens[i + 2] + 2 * values["padding"][i].to<int64_t>())
            {
                MIGRAPHX_THROW("PARSE_POOLING: kernel shape is too large");
            }
        }
        op.from_value(values);
        auto l1 = mm->add_instruction(op, l0);
        if(!slice_start.empty())
        {
            std::vector<int64_t> axes(kdims);
            std::iota(axes.begin(), axes.end(), 2);
            l1 = mm->add_instruction(op::slice{axes, slice_start, slice_end}, l1);
        }

        return l1;
    }
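
    // For example, GlobalAveragePool on a 1x3x7x7 input sets values["lengths"]
    // to {7, 7}, collapsing each feature map to one value, while plain
    // MaxPool/AveragePool take the window from the kernel_shape attribute.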

    instruction_ref
    parse_reshape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::reshape op;
        if(args.size() == 1)
        {
            literal s = parse_value(info.attributes.at("shape"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }
        if(args.size() == 2)
        {
            auto s = args[1]->eval();
            check_arg_empty(s, "Reshape: dynamic shape is not supported");
            s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        }

        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    static const auto& get_nearest_op(const std::string& mode)
    {
        using nearest_op = std::function<std::size_t(std::size_t, double)>;
        static std::unordered_map<std::string, nearest_op> const nearest_ops = {
            {"round_prefer_floor",
             [=](std::size_t d_in, double val) {
                 val = std::max(0.0, std::min(d_in - 1.0, val));
                 return static_cast<std::size_t>(std::ceil((val - 0.5)));
             }},
            {"round_prefer_ceil",
             [=](std::size_t d_in, double val) {
                 val = std::max(0.0, std::min(d_in - 1.0, val));
                 return static_cast<std::size_t>(std::round((val)));
             }},
            {"floor",
             [=](std::size_t d_in, double val) {
                 val = std::max(0.0, std::min(d_in - 1.0, val));
                 return static_cast<std::size_t>(std::floor((val)));
             }},
            {"ceil", [=](std::size_t d_in, double val) {
                 val = std::max(0.0, std::min(d_in - 1.0, val));
                 return static_cast<std::size_t>(std::ceil((val)));
             }}};

        if(!contains(nearest_ops, mode))
        {
            MIGRAPHX_THROW("PARSE_RESIZE: nearest_mode " + mode + " not supported!");
        }

        return nearest_ops.at(mode);
    }
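
    // Example: for d_in = 10 and val = 2.5, "round_prefer_floor" gives
    // ceil(2.5 - 0.5) = 2 while "round_prefer_ceil" gives round(2.5) = 3;
    // every mode first clamps val into [0, d_in - 1].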

    static const auto& get_original_idx_op(const std::string& mode)
    {
        using original_idx_op =
            std::function<double(std::size_t, std::size_t, std::size_t, double)>;
        static std::unordered_map<std::string, original_idx_op> const idx_ops = {
            {"half_pixel",
             [=](std::size_t, std::size_t, std::size_t idx, double scale) {
                 return (idx + 0.5) / scale - 0.5;
             }},
            {"pytorch_half_pixel",
             [=](std::size_t, std::size_t l_out, std::size_t idx, double scale) {
                 return l_out > 1 ? (idx + 0.5) / scale - 0.5 : 0.0;
             }},
            {"align_corners",
             [=](std::size_t l_in, std::size_t l_out, std::size_t idx, double) {
                 return 1.0 * idx * (l_in - 1.0) / (l_out - 1.0);
             }},
            {"asymmetric",
             [=](std::size_t, std::size_t, std::size_t idx, double scale) { return idx / scale; }},
            {"tf_half_pixel_for_nn", [=](std::size_t, std::size_t, std::size_t idx, double scale) {
                 return (idx + 0.5) / scale;
             }}};

        if(!contains(idx_ops, mode))
        {
            MIGRAPHX_THROW("PARSE_RESIZE: coordinate_transformation_mode " + mode +
                           " not supported!");
        }

        return idx_ops.at(mode);
    }
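
    // Example: upscaling by scale = 2 with "half_pixel" maps output index 1
    // back to (1 + 0.5) / 2 - 0.5 = 0.25 in the input, whereas "asymmetric"
    // maps it to 1 / 2 = 0.5.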

    instruction_ref
    parse_resize(const std::string&, const node_info& info, std::vector<instruction_ref> args)
    {
        std::string coord_trans_mode = "half_pixel";
        if(contains(info.attributes, "coordinate_transformation_mode"))
        {
            coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s();
            // does not support transformation mode "tf_crop_and_resize"
            if(coord_trans_mode == "tf_crop_and_resize")
            {
                MIGRAPHX_THROW("PARSE_RESIZE: \"tf_crop_and_resize\" mode is not supported!");
            }
        }

        // mode: only nearest mode is supported for now
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "nearest")
            {
                MIGRAPHX_THROW("PARSE_RESIZE: only nearest mode is supported!");
            }
        }

        // nearest mode
        std::string nearest_mode = "round_prefer_floor";
        if(contains(info.attributes, "nearest_mode"))
        {
            nearest_mode = info.attributes.at("nearest_mode").s();
        }

        // check exclude_outside, only support 0
        if(contains(info.attributes, "exclude_outside"))
        {
            int exclude_outside = info.attributes.at("exclude_outside").i();
            if(exclude_outside == 1)
            {
                MIGRAPHX_THROW("PARSE_RESIZE: exclude_outside 1 is not supported!");
            }
        }

        // input data shape info
        auto in_s    = args[0]->get_shape();
        auto in_lens = in_s.lens();

        // output shape is explicitly specified
        std::vector<std::size_t> out_lens(in_lens.size());

        // scale
        std::vector<double> vec_scale;

        // output size is specified in input, so use it as output size
        if(args.size() == 4 and args.back()->name() != "undefined")
        {
            auto arg_out_s = args[3]->eval();
            check_arg_empty(arg_out_s, "PARSE_RESIZE: dynamic output size is not supported!");
            arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });

            if(out_lens.size() != in_lens.size())
            {
                MIGRAPHX_THROW("PARSE_RESIZE: specified output size does not match input size");
            }

            // compute the scale
            vec_scale.resize(in_lens.size());
            std::transform(in_lens.begin(),
                           in_lens.end(),
                           out_lens.begin(),
                           vec_scale.begin(),
                           [](auto iss, auto oss) { return 1.0 * oss / iss; });
        }
        // need to compute the output lens from input
        else
        {
            auto arg_scale = args[2]->eval();
            check_arg_empty(arg_scale, "PARSE_RESIZE: dynamic input scale is not supported!");

            arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
            if(in_lens.size() != vec_scale.size())
            {
                MIGRAPHX_THROW("PARSE_RESIZE: ranks of input and scale are different!");
            }

            std::transform(
                in_lens.begin(),
                in_lens.end(),
                vec_scale.begin(),
                out_lens.begin(),
                [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
        }

        shape out_s{in_s.type(), out_lens};
        std::vector<int> ind(out_s.elements());

        // map out_idx to in_idx
        auto nearest_op = get_nearest_op(nearest_mode);
        auto idx_op     = get_original_idx_op(coord_trans_mode);

        shape_for_each(out_s, [&](auto idx) {
            auto in_idx = idx;
            for(std::size_t ii = 0; ii < in_lens.size(); ++ii)
            {
                auto idx_val = idx_op(in_lens[ii], out_lens[ii], in_idx[ii], vec_scale[ii]);
                in_idx[ii]   = nearest_op(in_lens[ii], idx_val);
            }

            ind[out_s.index(idx)] = static_cast<int>(in_s.index(in_idx));
        });

        // reshape input to one-dimension
        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
        shape ind_s{shape::int32_type, out_lens};
        auto rsp     = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
        auto ins_ind = mm->add_literal(literal(ind_s, ind));
        return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
    }
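
    // Illustrative walk-through (assumed shapes, not part of the parser):
    // resizing a 1x1x2x2 input with scales {1, 1, 2, 2} gives
    // out_lens = {1, 1, 4, 4}. For every output coordinate, idx_op maps it to
    // a fractional input coordinate, nearest_op snaps that to an integer, and
    // in_s.index() flattens it into ind; the final gather over the flattened
    // input then implements nearest-neighbor resize without a dedicated
    // resize operator.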

    instruction_ref
    parse_gather_elements(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        // standardize input data and index
        auto arg_data = make_contiguous(args[0]);
        auto arg_ind  = make_contiguous(args[1]);

        auto data_s = arg_data->get_shape();
        auto ind_s  = arg_ind->get_shape();

        if(data_s.lens().size() != ind_s.lens().size())
        {
            MIGRAPHX_THROW("PARSE_GATHER_ELEMENTS: input data and index must have the same rank!");
        }

        int n_rank     = static_cast<int>(data_s.lens().size());
        int tuned_axis = (axis < 0) ? (axis + n_rank) : axis;

        auto axis_stride      = data_s.strides()[tuned_axis];
        int64_t data_elem_num = static_cast<int64_t>(data_s.elements());
        // reshape the input data into one dimension and use it as
        // the input to the gather operator
        arg_data = mm->add_instruction(op::reshape{{data_elem_num}}, arg_data);

        std::size_t elem_num = ind_s.elements();
        std::vector<int> ind_index(elem_num);
        std::iota(ind_index.begin(), ind_index.end(), 0);

        // convert index in input indices to that in input data
        std::vector<int> data_indices(elem_num);
        std::transform(ind_index.begin(), ind_index.end(), data_indices.begin(), [&](auto i) {
            return data_s.index(ind_s.multi(i));
        });

        std::vector<int> vec_axis_ind(elem_num);
        std::transform(ind_index.begin(), ind_index.end(), vec_axis_ind.begin(), [&](auto i) {
            return ind_s.multi(i)[tuned_axis];
        });

        auto l_shape_idx =
            mm->add_literal(literal(ind_s, data_indices.begin(), data_indices.end()));
        auto l_dim_idx = mm->add_literal(literal(ind_s, vec_axis_ind.begin(), vec_axis_ind.end()));
        auto l_stride  = mm->add_literal(literal{{ind_s.type(), {1}}, {axis_stride}});
        l_stride       = mm->add_instruction(op::multibroadcast{ind_s.lens()}, l_stride);
        auto dim_diff  = mm->add_instruction(make_op("sub"), arg_ind, l_dim_idx);
        auto delta     = mm->add_instruction(make_op("mul"), dim_diff, l_stride);
        auto ind       = mm->add_instruction(make_op("add"), l_shape_idx, delta);

        op::gather op{0};
        return mm->add_instruction(op, arg_data, ind);
    }
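
    // Illustrative sketch (assumed values, not part of the parser): for 2x2
    // data with strides {2, 1}, axis = 1, and indices {{1, 0}, {0, 1}}, index
    // entry (0, 0) starts at the flattened position of (0, 0) in data, which
    // is 0, then moves (1 - 0) * axis_stride = 1 element, so it gathers
    // data[0][1]. The sub/mul/add instructions above compute this delta for
    // every index element at once.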

    instruction_ref
    parse_slice(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        op::slice op;

        // slice can have up to 5 inputs, we first check the 5th one
        // to decide whether MIGRAPHX can handle this slice
        if(args.size() == 5)
        {
            migraphx::argument step_arg = args.back()->eval();
            check_arg_empty(step_arg, "PARSE_SLICE: cannot handle variable steps for slice");
            std::vector<int> steps;
            step_arg.visit([&](auto s) { steps.assign(s.begin(), s.end()); });
            if(!std::all_of(steps.begin(), steps.end(), [](auto s) { return s == 1; }))
            {
                MIGRAPHX_THROW("PARSE_SLICE: cannot handle step other than 1");
            }
        }

        if(args.size() >= 4)
        {
            migraphx::argument axes_arg = args.at(3)->eval();
            check_arg_empty(axes_arg, "PARSE_SLICE: cannot handle variable axes for slice");
            axes_arg.visit([&](auto s) { op.axes.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "axes"))
        {
            literal s = parse_value(info.attributes.at("axes"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
        }

        if(args.size() >= 3)
        {
            migraphx::argument end_arg = args.at(2)->eval();
            check_arg_empty(end_arg, "PARSE_SLICE: cannot handle variable ends for slice");
            end_arg.visit([&](auto s) { op.ends.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "ends"))
        {
            literal s = parse_value(info.attributes.at("ends"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
        }

        if(args.size() >= 2)
        {
            migraphx::argument start_arg = args.at(1)->eval();
            check_arg_empty(start_arg, "PARSE_SLICE: cannot handle variable starts for slice");
            start_arg.visit([&](auto s) { op.starts.assign(s.begin(), s.end()); });
        }
        else if(contains(info.attributes, "starts"))
        {
            literal s = parse_value(info.attributes.at("starts"));
            s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
        }

        // default to slicing over all dimensions in order
        if(op.axes.empty())
        {
            std::vector<int64_t> axes(args[0]->get_shape().lens().size());
            std::iota(axes.begin(), axes.end(), int64_t{0});
            op.axes = axes;
        }

        return mm->add_instruction(op, args[0]);
    }
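
    // Illustrative sketch (assumed values, not part of the parser): an
    // opset-10 node Slice(data, starts = {1}, ends = {3}, axes = {0}) passes
    // these values as constant inputs, which are evaluated and copied into
    // op above; an opset-1 node passes the same values as attributes instead.
    // Either way the result is op::slice{axes = {0}, starts = {1},
    // ends = {3}} applied to args[0].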

    instruction_ref
    parse_constant(const std::string&, node_info info, const std::vector<instruction_ref>&) const
    {
        literal v = parse_value(info.attributes.at("value"));
        // return empty literal
        if(v.get_shape().elements() == 0)
        {
            return mm->add_literal(literal{});
        }

        auto dim_size = info.attributes.at("value").t().dims_size();
        // if dim_size is 0, it is a scalar
        if(dim_size == 0)
        {
            migraphx::shape scalar_shape{v.get_shape().type()};
            return mm->add_literal(migraphx::literal{scalar_shape, v.data()});
        }

        return mm->add_literal(v);
    }

    instruction_ref
    parse_gemm(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float alpha = 1.0f;
        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        if(contains(info.attributes, "beta"))
        {
            beta = parse_value(info.attributes.at("beta")).at<float>();
        }
        if(contains(info.attributes, "transA"))
        {
            transa = parse_value(info.attributes.at("transA")).at<bool>();
        }
        if(contains(info.attributes, "transB"))
        {
            transb = parse_value(info.attributes.at("transB")).at<bool>();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::swap(*perm.rbegin(), *(perm.rbegin() + 1));

        auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
            if(beta != 0.f && args[2]->get_shape().elements() > 0)
            {
                auto out_lens   = l1->get_shape().lens();
                out_lens.back() = l2->get_shape().lens().back();
                auto l3         = args[2];
                auto l3_lens    = l3->get_shape().lens();
                if(!std::equal(out_lens.begin(), out_lens.end(), l3_lens.begin(), l3_lens.end()))
                {
                    l3 = mm->add_instruction(op::multibroadcast{out_lens}, args[2]);
                }
                return mm->add_instruction(
                    make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3);
            }
        }

        return mm->add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2);
    }
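
    // A sketch of the mapping above (not part of the parser): Gemm computes
    //   Y = alpha * op(A) * op(B) + beta * C
    // where op() optionally transposes the last two dims. For example, a node
    // with transA = 1 and alpha = 2.0 becomes a transpose of args[0] followed
    // by dot with alpha = 2.0; C is broadcast and appended only when beta is
    // nonzero and C is non-empty.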

    instruction_ref parse_matmul(const std::string&,
                                 const std::string& op_name,
                                 const node_info&,
                                 std::vector<instruction_ref> args)
    {
        auto l0      = args[0];
        auto l1      = args[1];
        auto l0_lens = l0->get_shape().lens();
        auto l1_lens = l1->get_shape().lens();

        // args[0] is a vector, prepend 1 to the shape
        bool is_a_prepended = false;
        if(l0_lens.size() == 1)
        {
            is_a_prepended = true;
            l0_lens.insert(l0_lens.begin(), 1);
            l0 = mm->add_instruction(op::unsqueeze{{0}}, args[0]);
        }

        // args[1] is a vector, append 1 to the shape
        bool is_b_appended = false;
        if(l1_lens.size() == 1)
        {
            is_b_appended = true;
            l1_lens.push_back(1);
            l1 = mm->add_instruction(op::unsqueeze{{1}}, args[1]);
        }

        instruction_ref bl0 = l0;
        instruction_ref bl1 = l1;
        if(!std::equal(l0_lens.rbegin() + 2, l0_lens.rend(), l1_lens.rbegin() + 2, l1_lens.rend()))
        {
            auto l0_it = l0_lens.begin() + l0_lens.size() - 2;
            std::vector<std::size_t> l0_broadcasted_lens(l0_lens.begin(), l0_it);
            auto l1_it = l1_lens.begin() + l1_lens.size() - 2;
            std::vector<std::size_t> l1_broadcasted_lens(l1_lens.begin(), l1_it);
            auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens);
            l0_broadcasted_lens = output_lens;
            l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, l0_lens.end());
            l1_broadcasted_lens = output_lens;
            l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, l1_lens.end());
            if(l0_lens != l0_broadcasted_lens)
            {
                bl0 = mm->add_instruction(op::multibroadcast{l0_broadcasted_lens}, l0);
            }
            if(l1_lens != l1_broadcasted_lens)
            {
                bl1 = mm->add_instruction(op::multibroadcast{l1_broadcasted_lens}, l1);
            }
        }

        auto dot_res = mm->add_instruction(make_op(op_name, {{"alpha", 1}, {"beta", 0}}), bl0, bl1);
        int64_t num_axis = static_cast<int64_t>(dot_res->get_shape().lens().size());
        if(is_a_prepended)
        {
            dot_res = mm->add_instruction(op::squeeze{{num_axis - 2}}, dot_res);
            --num_axis;
        }
        if(is_b_appended)
        {
            dot_res = mm->add_instruction(op::squeeze{{num_axis - 1}}, dot_res);
        }

        return dot_res;
    }
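
    // Illustrative sketch (assumed shapes, not part of the parser): for A
    // with lens {2, 3, 4} and B with lens {4, 5}, the batch dims {2} and {}
    // broadcast to {2}, so B is multibroadcast to {2, 4, 5} and the dot
    // yields {2, 3, 5}. For a rank-1 A with lens {4}, A is unsqueezed to
    // {1, 4}, the dot runs, and the prepended axis is squeezed away again.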

    instruction_ref
    parse_batchnorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        if(contains(info.attributes, "momentum"))
        {
            momentum = parse_value(info.attributes.at("momentum")).at<float>();
        }
        if(contains(info.attributes, "spatial"))
        {
            bn_mode = (parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0)
                          ? op::batch_norm_inference::spatial
                          : op::batch_norm_inference::per_activation;
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return mm->add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_instancenorm(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        // y = scale * ( x - mean ) / sqrt ( variance + epsilon ) + bias
        // mean = reduce_mean({D1, D2, ... Dk}, x)
        // variance = reduce_mean({D1, D2, ... Dk}, (x - mean)^2)

        float epsilon = 1e-5f;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parse_value(info.attributes.at("epsilon")).at<float>();
        }
        auto x     = args[0];
        auto scale = args[1];
        auto bias  = args[2];
        auto dims  = x->get_shape().lens();
        auto ndims = dims.size();
        assert(ndims >= 2);
        auto kdims = ndims - 2;

        std::vector<int64_t> axes(kdims);
        std::iota(axes.begin(), axes.end(), 2);

        auto mean            = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), x);
        auto mean_bcast      = mm->add_instruction(op::multibroadcast{dims}, mean);
        auto l0              = mm->add_instruction(make_op("sqdiff"), x, mean_bcast);
        auto variance        = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), l0);
        auto l1              = mm->add_instruction(make_op("sub"), x, mean_bcast);
        auto epsilon_literal = mm->add_literal(epsilon);
        auto epsilon_bcast   = mm->add_instruction(op::multibroadcast{dims}, epsilon_literal);
        auto variance_bcast  = mm->add_instruction(op::multibroadcast{dims}, variance);
        auto l2              = mm->add_instruction(make_op("add"), variance_bcast, epsilon_bcast);
        auto l3              = mm->add_instruction(make_op("rsqrt"), l2);
        auto l4              = mm->add_instruction(make_op("mul"), l1, l3);
        auto scale_bcast     = mm->add_instruction(op::broadcast{1, dims}, scale);
        auto bias_bcast      = mm->add_instruction(op::broadcast{1, dims}, bias);
        auto l5              = mm->add_instruction(make_op("mul"), l4, scale_bcast);
        return mm->add_instruction(make_op("add"), l5, bias_bcast);
    }
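
    // Quick numeric check of the formula above (a sketch, not part of the
    // parser): a single channel holding {1, 3} has mean = 2 and variance = 1,
    // so with epsilon ~ 0, scale = 1, and bias = 0 the output is {-1, 1}:
    // each element is centered by the per-channel mean and scaled by
    // rsqrt(variance + epsilon).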

    instruction_ref
    parse_leaky_relu(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float alpha = 0.01; // default alpha val for leaky relu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        auto op = make_op("leaky_relu", {{"alpha", alpha}});
        return mm->add_instruction(op, args.front());
    }

    instruction_ref
    parse_elu(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float alpha = 1.0; // default alpha val for elu
        if(contains(info.attributes, "alpha"))
        {
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        }
        auto op = make_op("elu", {{"alpha", alpha}});
        return mm->add_instruction(op, args.front());
    }

    instruction_ref
    parse_lrn(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float alpha = 0.0001;
        float beta  = 0.75;
        float bias  = 1.0;
        int size    = 1;
        if(contains(info.attributes, "alpha"))
            alpha = parse_value(info.attributes.at("alpha")).at<float>();
        if(contains(info.attributes, "beta"))
            beta = parse_value(info.attributes.at("beta")).at<float>();
        if(contains(info.attributes, "bias"))
            bias = parse_value(info.attributes.at("bias")).at<float>();
        if(contains(info.attributes, "size"))
            size = parse_value(info.attributes.at("size")).at<int>();
        op::lrn op{alpha, beta, bias, size};
        return mm->add_instruction(op, args.front());
    }

    instruction_ref
    parse_imagescaler(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        float scale = 1.0;
        std::vector<float> bias{};
        if(contains(info.attributes, "scale"))
        {
            scale = parse_value(info.attributes.at("scale")).at<float>();
        }

        if(contains(info.attributes, "bias"))
        {
            auto&& bias_floats = info.attributes["bias"].floats();
            bias               = std::vector<float>(bias_floats.begin(), bias_floats.end());
        }
        auto input_shape       = args.front()->get_shape();
        auto const& input_lens = input_shape.lens();
        auto input_type        = input_shape.type();

        auto scale_val = mm->add_literal(literal{shape{input_type}, {scale}});
        auto bias_vals = mm->add_literal(literal{shape{input_type, {bias.size()}}, bias});

        auto scale_tensor = mm->add_instruction(migraphx::op::scalar{input_lens}, scale_val);
        auto img_scaled = mm->add_instruction(migraphx::make_op("mul"), args.front(), scale_tensor);
        auto bias_bcast = mm->add_instruction(migraphx::op::broadcast{1, input_lens}, bias_vals);
        return mm->add_instruction(migraphx::make_op("add"), img_scaled, bias_bcast);
    }

    instruction_ref
    parse_transpose(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        std::vector<int64_t> perm{};
        if(contains(info.attributes, "perm"))
        {
            auto&& perm_vals = info.attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
        return mm->add_instruction(migraphx::op::transpose{perm}, args.front());
    }

    instruction_ref parse_pad(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        std::vector<int64_t> pads{};
        if(args.size() >= 2)
        {
            auto pad_arg = args.at(1)->eval();
            check_arg_empty(pad_arg, "PARSE_PAD: pad input must be constant");
            pad_arg.visit([&](auto v) { pads.assign(v.begin(), v.end()); });
        }
        else if(contains(info.attributes, "pads"))
        {
            auto&& pad_vals = info.attributes["pads"].ints();
            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
        }
        else
        {
            MIGRAPHX_THROW("PARSE_PAD: pad must be available");
        }

        // check if padding is actually being done (at least one value is nonzero)
        if(std::all_of(pads.begin(), pads.end(), [](const int64_t& i) { return i == 0; }))
        {
            return mm->add_instruction(make_op("identity"), args.front());
        }

        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode == "reflect")
                return reflect_pad(pads, args.front());
            if(mode != "constant")
            {
                MIGRAPHX_THROW(
                    "PARSE_PAD: migraphx currently only supports constant and reflect padding");
            }
        }

        float value = 0.0f;
        // third input is the value
        if(args.size() == 3)
        {
            auto val_ins = args.at(2);
            if(!val_ins->can_eval())
            {
                MIGRAPHX_THROW("PARSE_PAD: input value must be constant");
            }
            auto val_arg = val_ins->eval();
            if(val_arg.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("PARSE_PAD: value should contain only one element");
            }
            value = val_arg.at<float>();
        }
        else if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        return mm->add_instruction(migraphx::op::pad{pads, value}, args.front());
    }
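
    // Layout note (a sketch with assumed values, not part of the parser):
    // ONNX orders pads as {x1_begin, x2_begin, ..., x1_end, x2_end, ...}. For
    // a 2-D input, pads = {0, 1, 0, 1} therefore adds one column before and
    // one after the last axis, filled with `value` in constant mode or
    // mirrored in reflect mode.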

    instruction_ref
    parse_selu(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
    {
        auto type   = args[0]->get_shape().type();
        auto lens   = args[0]->get_shape().lens();
        float alpha = 1.67326f;
        if(contains(info.attributes, "alpha"))
        {
            alpha = info.attributes.at("alpha").f();
        }

        float gamma = 1.0507f;
        if(contains(info.attributes, "gamma"))
        {
            gamma = info.attributes.at("gamma").f();
        }

        auto l_alpha = mm->add_literal({{type, {1}}, {alpha}});
        auto l_gamma = mm->add_literal({{type, {1}}, {gamma / 2.0f}});
        if(lens != std::vector<std::size_t>{1})
        {
            l_alpha =
                mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_alpha);
            l_gamma =
                mm->add_instruction(make_op("multibroadcast", {{"output_lens", lens}}), l_gamma);
        }

        auto sign_x = mm->add_instruction(make_op("sign"), args[0]);
        auto exp_x  = mm->add_instruction(make_op("exp"), args[0]);

        auto alpha_ex  = mm->add_instruction(make_op("mul"), l_alpha, exp_x);
        auto aex_alpha = mm->add_instruction(make_op("sub"), alpha_ex, l_alpha);

        auto ins1 = mm->add_instruction(make_op("add"), aex_alpha, args[0]);
        auto ins2 = mm->add_instruction(make_op("sub"), aex_alpha, args[0]);

        auto sign2   = mm->add_instruction(make_op("mul"), sign_x, ins2);
        auto ins_sub = mm->add_instruction(make_op("sub"), ins1, sign2);

        return mm->add_instruction(make_op("mul"), ins_sub, l_gamma);
    }
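
    // Why the sign trick above works (a sketch, not part of the parser): let
    // A = alpha * exp(x) - alpha. The graph computes
    //   gamma / 2 * ((A + x) - sign(x) * (A - x))
    // For x > 0, sign(x) = 1 and the A terms cancel, leaving gamma * x. For
    // x < 0, sign(x) = -1 and the x terms cancel, leaving
    // gamma * alpha * (exp(x) - 1), which matches the SELU definition without
    // an elementwise branch.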

    // Use a literal instruction to replace the shape, since the output of
    // the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const node_info&, std::vector<instruction_ref> args) const
    {
        if(args.size() != 1)
            MIGRAPHX_THROW("Shape: operator should have 1 operand");
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int64_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
            return int64_t(i);
        });
        return mm->add_literal(migraphx::literal{s, vec_shape});
    }

    // Use a literal instruction to replace the constantFill operator. In RNN, input shape
    // and value are fixed, so no need to do the actual computation for the constantFill
    // operator
    instruction_ref
    parse_constant_fill(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int input_as_shape = 0;
        int dtype          = 1;
        float value        = 0.0f;

        if(contains(info.attributes, "dtype"))
        {
            dtype = parse_value(info.attributes.at("dtype")).at<int>();
        }
        shape::type_t type = get_type(dtype);

        if(contains(info.attributes, "input_as_shape"))
        {
            input_as_shape = parse_value(info.attributes.at("input_as_shape")).at<int>();
        }

        if(contains(info.attributes, "value"))
        {
            value = parse_value(info.attributes.at("value")).at<float>();
        }

        if(contains(info.attributes, "extra_shape"))
        {
            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
        }

        if(input_as_shape == 1)
        {
            if(args.size() != 1)
            {
                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
            }

            if(contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
                               "at the same time");
            }

            migraphx::argument in = args[0]->eval();
            check_arg_empty(in, "ConstantFill: dynamic shape is not supported");

            std::vector<std::size_t> dims;
            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
            migraphx::shape s(type, dims);
            std::vector<float> values(s.elements(), value);
            return mm->add_literal(migraphx::literal(s, values));
        }
        else if(input_as_shape == 0)
        {
            if(!contains(info.attributes, "shape"))
            {
                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
            }

            literal ls = parse_value(info.attributes.at("shape"));
            std::vector<std::size_t> dims;
            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
            migraphx::shape s{type, dims};
            std::vector<float> values(s.elements(), value);
            return mm->add_literal(migraphx::literal(s, values));
        }
        else
        {
            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
        }
    }

    instruction_ref
    parse_constant_of_shape(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        literal l_val{};
        if(contains(info.attributes, "value"))
        {
            l_val = parse_value(info.attributes.at("value"));
            if(l_val.get_shape().elements() != 1)
            {
                MIGRAPHX_THROW("ConstantOfShape: attribute value can contain only 1 element!");
            }
        }
        else
        {
            l_val = literal({shape::float_type, {1}, {0}}, {0.0f});
        }

        // input is empty, output is a scalar
        auto type = l_val.get_shape().type();

        if(args.empty())
        {
            MIGRAPHX_THROW("ConstantOfShape: must have 1 input!");
        }
        else
        {
            migraphx::shape s;
            // empty input tensor, output is a scalar
            if(args[0]->get_shape().elements() == 0)
            {
                s = migraphx::shape{type, {1}, {0}};
            }
            else
            {
                migraphx::argument in = args[0]->eval();
                check_arg_empty(in, "ConstantOfShape: dynamic shape is not supported");

                std::vector<std::size_t> dims;
                in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
                s = migraphx::shape{type, dims};
            }

            literal l_out{};
            l_val.visit([&](auto val) {
                using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
                // l_val contains only one element
                std::vector<val_type> out_vec(s.elements(), val.front());
                l_out = literal(s, out_vec);
            });

            return mm->add_literal(l_out);
        }
    }

    instruction_ref
    parse_expand(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto in_lens             = args[0]->get_shape().lens();
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "Expand: dynamic shape is not supported");
        std::vector<std::size_t> dims;
        arg_s.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
        auto out_lens = compute_broadcasted_lens(in_lens, dims);
        return mm->add_instruction(op::multibroadcast{out_lens}, args[0]);
    }
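
    // Illustrative sketch (assumed shapes, not part of the parser): an input
    // with lens {3, 1} and a shape input holding {2, 1, 4} broadcast to
    // out_lens = {2, 3, 4}, so Expand lowers to a single multibroadcast of
    // the input.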

    std::vector<instruction_ref>
    parse_rnn(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[1]->get_shape().lens()[1];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names{"tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(*name_it) + " not supported");
        }

        // The bidirectional case should have two activation functions:
        // one for the forward direction and one for the reverse direction.
        // If only one actv function is provided, use it for both
        // the forward and reverse directions.
        if(dirct == op::rnn_direction::bidirectional)
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& fn) { return map_actv_funcs[fn]; });

        // To be added later
        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        // if the number of arguments is less than 6, append
        // undefined operators so there are 6 arguments
        if(args.size() < 6)
        {
            auto ins = mm->add_instruction(op::undefined{});
            args.insert(args.end(), (6 - args.size()), ins);
        }

        // first output for the concatenation of hidden states
        auto hidden_states =
            mm->add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));

        // second output for the last hidden state
        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    std::vector<instruction_ref>
    parse_gru(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

        if(contains(info.attributes, "hidden_size"))
        {
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
        if(contains(info.attributes, "direction"))
        {
            direction = info.attributes.at("direction").s();
        }

        op::rnn_direction dirct = op::rnn_direction::forward;
        if(direction == "bidirectional")
        {
            dirct = op::rnn_direction::bidirectional;
        }
        else if(direction == "reverse")
        {
            dirct = op::rnn_direction::reverse;
        }

        std::vector<std::string> vec_names = {"sigmoid", "tanh"};
        if(contains(info.attributes, "activations"))
        {
            auto names = info.attributes.at("activations").strings();
            vec_names.clear();
            vec_names.resize(names.size());
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
        }

        // bidirectional cases need 4 activation functions, others need 2
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 4 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // repeat it four times. If 2 actv functions are provided,
            // assume forward and reverse use the same pair of actv
            // functions. For the case of 3 actv functions provided,
            // assume the 3rd one is repeated once and used by the
            // reverse direction.
            // This may need to change later.
            if(vec_names.size() == 1)
            {
                vec_names.insert(vec_names.end(), 3, vec_names.at(0));
            }
            else if(vec_names.size() == 2)
            {
                // repeat the activation functions
                vec_names.push_back(vec_names.at(0));
                vec_names.push_back(vec_names.at(1));
            }
            else if(vec_names.size() == 3)
            {
                vec_names.push_back(vec_names.at(2));
            }
        }
        else
        {
            if(vec_names.size() == 1)
            {
                vec_names.push_back(vec_names.at(0));
            }
        }

        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
        if(name_it != vec_names.end())
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(*name_it) + " not supported");
        }

        std::vector<operation> vec_actv_funcs(vec_names.size());
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });

        float clip = 0.0;
        if(contains(info.attributes, "clip"))
        {
            clip = parse_value(info.attributes.at("clip")).at<float>();
        }

        int linear_before_reset = 0;
        if(contains(info.attributes, "linear_before_reset"))
        {
            linear_before_reset = parse_value(info.attributes.at("linear_before_reset")).at<int>();
        }

        // append undefined operators to make 6 arguments
        if(args.size() < 6)
        {
            auto ins = mm->add_instruction(op::undefined{});
            args.insert(args.end(), 6 - args.size(), ins);
        }

        // first output for concatenation of hidden states
        auto hidden_states = mm->add_instruction(
            op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
            std::move(args));

        // second output for last gru output
        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);

        return {hidden_states, last_output};
    }

    void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv_func_names)
    {
        // need 6 activation functions for the bidirectional case
        if(dirct == op::rnn_direction::bidirectional)
        {
            // 6 activation functions are used in the bidirectional
            // scenario. No spec is provided in onnx::operator, so we
            // use the following algorithm: if 1 actv function is provided,
            // repeat it six times. If 2 actv functions are provided,
            // repeat the 2nd once, then repeat all three once.
            // If 3 actv funcs are provided, repeat all three once.
            // The same algorithm is used when 4, 5, or 6 actv functions
            // are provided. This may need to change later.
            switch(actv_func_names.size())
            {
            case 1:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0),
                                   actv_func_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, then repeat all three another time
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(1),
                                   actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(1)};
                break;

            case 3:
                // repeat all three actv funcs once
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2)};
                break;

            case 4:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(3),
                                   actv_func_names.at(3),
                                   actv_func_names.at(3)};
                break;

            case 5:
                actv_func_names = {actv_func_names.at(0),
                                   actv_func_names.at(1),
                                   actv_func_names.at(2),
                                   actv_func_names.at(3),
                                   actv_func_names.at(4),
                                   actv_func_names.at(4)};
                break;

            default: break;
            }
        }
        else
        {
            switch(actv_func_names.size())
            {
            case 1:
                actv_func_names = {
                    actv_func_names.at(0), actv_func_names.at(0), actv_func_names.at(0)};
                break;

            case 2:
                // repeat the 2nd actv func once, so we have 3 actv funcs
                actv_func_names = {
                    actv_func_names.at(0), actv_func_names.at(1), actv_func_names.at(1)};
                break;

            default: break;
            }
        }
    }
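
    // Illustrative sketch (not part of the parser): for a bidirectional LSTM
    // with activations = {"sigmoid", "tanh"}, the rules above expand the list
    // to {sigmoid, tanh, tanh, sigmoid, tanh, tanh}: one {f, g, h} triple for
    // the forward direction and one for the reverse direction.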

Shucai Xiao's avatar
Shucai Xiao committed
2225
    std::vector<instruction_ref>
2226
    parse_lstm(const std::string&, node_info info, std::vector<instruction_ref> args)
Shucai Xiao's avatar
Shucai Xiao committed
2227
2228
2229
2230
    {
        migraphx::shape input_shape = args[0]->get_shape();
        std::size_t hidden_size     = args[2]->get_shape().lens()[2];

2231
        if(contains(info.attributes, "hidden_size"))
Shucai Xiao's avatar
Shucai Xiao committed
2232
        {
2233
            std::size_t hidden_size_att = parse_value(info.attributes.at("hidden_size")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
2234
2235
2236
2237
2238
2239
2240
2241
            if(hidden_size != hidden_size_att)
            {
                MIGRAPHX_THROW("LSTM: hidden size mismatch in input and attribute");
            }
        }

        // Handling of direction to be added later
        std::string direction{"forward"};
2242
        if(contains(info.attributes, "direction"))
Shucai Xiao's avatar
Shucai Xiao committed
2243
        {
2244
            direction = info.attributes.at("direction").s();
Shucai Xiao's avatar
Shucai Xiao committed
2245
2246
        }

Shucai Xiao's avatar
Shucai Xiao committed
2247
        op::rnn_direction dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
2248
2249
        if(direction == "bidirectional")
        {
Shucai Xiao's avatar
Shucai Xiao committed
2250
            dirct = op::rnn_direction::bidirectional;
Shucai Xiao's avatar
Shucai Xiao committed
2251
2252
2253
        }
        else if(direction == "reverse")
        {
Shucai Xiao's avatar
Shucai Xiao committed
2254
            dirct = op::rnn_direction::reverse;
Shucai Xiao's avatar
Shucai Xiao committed
2255
        }
Shucai Xiao's avatar
Shucai Xiao committed
2256
        else if(direction == "forward")
Shucai Xiao's avatar
Shucai Xiao committed
2257
        {
Shucai Xiao's avatar
Shucai Xiao committed
2258
            dirct = op::rnn_direction::forward;
Shucai Xiao's avatar
Shucai Xiao committed
2259
2260
2261
2262
2263
2264
        }
        else
        {
            MIGRAPHX_THROW("LSTM: incorrect direction attribute");
        }

2265
        std::vector<std::string> vec_names = {"sigmoid", "tanh", "tanh"};
2266
        if(contains(info.attributes, "activations"))
Shucai Xiao's avatar
Shucai Xiao committed
2267
        {
2268
            auto names = info.attributes.at("activations").strings();
Shucai Xiao's avatar
Shucai Xiao committed
2269
2270
            vec_names.clear();
            vec_names.resize(names.size());
Shucai Xiao's avatar
Shucai Xiao committed
2271
2272
2273
            std::transform(names.begin(), names.end(), vec_names.begin(), [](auto name) {
                return to_lower(name);
            });
Shucai Xiao's avatar
Shucai Xiao committed
2274
2275
        }

Shucai Xiao's avatar
Shucai Xiao committed
2276
        lstm_actv_functions(dirct, vec_names);
Shucai Xiao's avatar
Shucai Xiao committed
2277

2278
2279
2280
        auto name_it = std::find_if(vec_names.begin(), vec_names.end(), [&](auto& name) {
            return (map_actv_funcs.count(name) == 0);
        });
Shucai Xiao's avatar
Shucai Xiao committed
2281
        if(name_it != vec_names.end())
2282
2283
2284
        {
            MIGRAPHX_THROW("LSTM: activation function " + std::string(*name_it) + " not supported");
        }
Shucai Xiao's avatar
Shucai Xiao committed
2285
2286

        std::vector<operation> vec_actv_funcs(vec_names.size());
Paul's avatar
Paul committed
2287
2288
2289
2290
        std::transform(vec_names.begin(),
                       vec_names.end(),
                       vec_actv_funcs.begin(),
                       [&](const auto& name) { return map_actv_funcs[name]; });
Shucai Xiao's avatar
Shucai Xiao committed
2291
2292

        float clip = 0.0;
2293
        if(contains(info.attributes, "clip"))
Shucai Xiao's avatar
Shucai Xiao committed
2294
        {
2295
            clip = parse_value(info.attributes.at("clip")).at<float>();
Shucai Xiao's avatar
Shucai Xiao committed
2296
2297
2298
        }

        int input_forget = 0;
2299
        if(contains(info.attributes, "input_forget"))
Shucai Xiao's avatar
Shucai Xiao committed
2300
        {
2301
            input_forget = parse_value(info.attributes.at("input_forget")).at<int>();
Shucai Xiao's avatar
Shucai Xiao committed
2302
2303
2304
2305
2306
        }

        // append undefined opeator to make 6 arguments
        if(args.size() < 8)
        {
2307
            auto ins = mm->add_instruction(op::undefined{});
Shucai Xiao's avatar
Shucai Xiao committed
2308
            args.insert(args.end(), 8 - args.size(), ins);
Shucai Xiao's avatar
Shucai Xiao committed
2309
2310
2311
        }

        // first output for concatenation of hidden states
        auto hidden_states = mm->add_instruction(
            op::lstm{hidden_size, vec_actv_funcs, dirct, clip, input_forget}, std::move(args));

        // second output for the last hidden state output
        auto last_output = mm->add_instruction(op::rnn_last_hs_output{}, hidden_states);

        // third output for last cell output
        auto last_cell_output = mm->add_instruction(op::rnn_last_cell_output{}, hidden_states);

        return {hidden_states, last_output, last_cell_output};
    }

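    // Common driver for the Reduce* operators: reduces over the axes given by
    // the "axes" attribute (defaulting to all axes), and squeezes the reduced
    // dimensions unless the "keepdims" attribute is nonzero. For example,
    // reducing a {2, 3, 4} input over axes {0, 2} yields {1, 3, 1} with
    // keepdims=1 and {3} with keepdims=0.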
    instruction_ref parse_reduce_oper(const std::string&,
                                      const std::string& op_name,
                                      node_info info,
                                      std::vector<instruction_ref> args) const
    {
        std::size_t n_dim = args.front()->get_shape().lens().size();

        // default to reduce over all dimensions
        std::vector<int64_t> axes(n_dim);
        std::iota(axes.begin(), axes.end(), 0);
        if(contains(info.attributes, "axes"))
        {
            axes.clear();
            auto&& attr_axes = info.attributes["axes"].ints();
            axes             = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
        }

        int keep_dims = 1;
        if(contains(info.attributes, "keepdims"))
        {
            keep_dims = parse_value(info.attributes.at("keepdims")).at<int>();
        }

        if(keep_dims == 1)
        {
            return mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
        }
        else
        {
            auto ins = mm->add_instruction(make_op(op_name, {{"axes", axes}}), std::move(args));
            return mm->add_instruction(op::squeeze{axes}, ins);
        }
    }

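    // ReduceL1(x) == ReduceSum(abs(x))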
    instruction_ref
    parse_reduce_l1(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        auto abs_ins = mm->add_instruction(make_op("abs"), args[0]);
        return parse_reduce_oper({}, "reduce_sum", std::move(info), {abs_ins});
    }

    instruction_ref
2365
    parse_reduce_l2(const std::string&, node_info info, std::vector<instruction_ref> args) const
Shucai Xiao's avatar
Shucai Xiao committed
2366
    {
2367
        auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
2368
        auto sum_ins    = parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
2369
        return mm->add_instruction(make_op("sqrt"), sum_ins);
Shucai Xiao's avatar
Shucai Xiao committed
2370
2371
    }

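    // ReduceLogSum(x) == log(ReduceSum(x))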
    instruction_ref parse_reduce_log_sum(const std::string&,
                                         node_info info,
                                         std::vector<instruction_ref> args) const
    {
        auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), std::move(args));
        return mm->add_instruction(make_op("log"), sum_ins);
    }

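    // ReduceLogSumExp(x) == log(ReduceSum(exp(x)))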
    instruction_ref parse_reduce_log_sum_exp(const std::string&,
                                             node_info info,
                                             std::vector<instruction_ref> args) const
    {
        auto exp_ins = mm->add_instruction(make_op("exp"), args[0]);
        auto sum_ins = parse_reduce_oper({}, "reduce_sum", std::move(info), {exp_ins});
        return mm->add_instruction(make_op("log"), sum_ins);
    }

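    // ReduceSumSquare(x) == ReduceSum(x * x)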
    instruction_ref parse_reduce_sum_square(const std::string&,
                                            node_info info,
                                            std::vector<instruction_ref> args) const
    {
        auto square_ins = mm->add_instruction(make_op("mul"), args[0], args[0]);
        return parse_reduce_oper({}, "reduce_sum", std::move(info), {square_ins});
    }

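    // Cast: converts the input to the type named by the required "to"
    // attribute, which holds an ONNX TensorProto data type code (see get_type).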
    instruction_ref
    parse_cast(const std::string&, node_info info, std::vector<instruction_ref> args) const
    {
        if(!contains(info.attributes, "to"))
        {
            MIGRAPHX_THROW("PARSE_CAST: missing to type attribute!");
        }

        int to_type        = parse_value(info.attributes.at("to")).at<int>();
        shape::type_t type = get_type(to_type);
        return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
    }

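    // Split: slices the input along "axis" into num_outputs pieces. The piece
    // sizes come from the "split" attribute when present (and must sum to the
    // dimension size); otherwise the dimension must divide evenly.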
    std::vector<instruction_ref>
    parse_split(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        if(contains(info.attributes, "axis"))
        {
            axis = parse_value(info.attributes.at("axis")).at<int>();
        }

        auto lens      = args[0]->get_shape().lens();
        int64_t n_rank = static_cast<int64_t>(lens.size());
        if((axis < -n_rank) || (axis >= n_rank))
        {
            MIGRAPHX_THROW("PARSE_SPLIT: axis attribute out of rank!");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;

        std::vector<int64_t> vec_splits;
        if(contains(info.attributes, "split"))
        {
            literal s = parse_value(info.attributes.at("split"));
            s.visit([&](auto v) { vec_splits.assign(v.begin(), v.end()); });

            if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
               static_cast<int64_t>(lens[tuned_axis]))
            {
                MIGRAPHX_THROW("PARSE_SPLIT: sum of split attribute unequal to dim size of axis!");
            }
        }
        // no split attribute, input is equally divided
        else
        {
            if((lens[tuned_axis] % info.num_outputs) != 0)
            {
                MIGRAPHX_THROW("PARSE_SPLIT: input cannot be equally divided into " +
                               to_string(info.num_outputs) + " splits!");
            }
            auto dl = lens[tuned_axis] / info.num_outputs;
            vec_splits.resize(info.num_outputs, dl);
        }

        std::vector<instruction_ref> ret_ins;
        int64_t start = 0;
        for(auto sl : vec_splits)
        {
            ret_ins.push_back(
                mm->add_instruction(op::slice{{axis}, {start}, {start + sl}}, args[0]));
            start += sl;
        }

        return ret_ins;
    }

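    // OneHot: built as a gather from a depth x depth identity matrix, then a
    // transpose to move the one-hot dimension to "axis", then scaled and
    // shifted by the off/on values supplied in args[2].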
    instruction_ref
    parse_onehot(const std::string&, node_info info, std::vector<instruction_ref> args)
    {
        migraphx::argument depth_arg = args[1]->eval();
        check_arg_empty(depth_arg, "PARSE_ONEHOT: depth - dynamic shape not supported");
        size_t depth = depth_arg.at<size_t>();

        int64_t axis = -1;
        if(contains(info.attributes, "axis"))
        {
            axis = info.attributes.at("axis").i();
        }

        std::vector<float> depth_input(depth * depth, 0.0f);
        for(int i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = 1.0f;
        }

        auto type = args[2]->get_shape().type();
        shape s{type, {depth, depth}};
        auto l_val      = mm->add_literal({s, depth_input});
        auto gather_out = mm->add_instruction(op::gather{0}, {l_val, args[0]});

        // Finally, we need a transpose to move the innermost dim to the axis dim
        int n_rank = gather_out->get_shape().lens().size();
        if(axis < -n_rank or axis >= n_rank)
        {
            MIGRAPHX_THROW("PARSE_ONEHOT: axis out of range");
        }
        int64_t tuned_axis = (axis < 0) ? axis + n_rank : axis;
        std::vector<int64_t> perm(n_rank - 1);
        std::iota(perm.begin(), perm.end(), 0);
        perm.insert(perm.begin() + tuned_axis, n_rank - 1);
        auto tr_out = mm->add_instruction(op::transpose{perm}, gather_out);
        auto lens   = tr_out->get_shape().lens();

        auto off_val       = mm->add_instruction(op::slice{{0}, {0}, {1}}, args[2]);
        auto on_val        = mm->add_instruction(op::slice{{0}, {1}, {2}}, args[2]);
        auto diff          = mm->add_instruction(make_op("sub"), on_val, off_val);
        auto unsq_off_val  = mm->add_instruction(op::multibroadcast{lens}, off_val);
        auto unsq_diff_val = mm->add_instruction(op::multibroadcast{lens}, diff);
        auto l_mul         = mm->add_instruction(make_op("mul"), tr_out, unsq_diff_val);
        return mm->add_instruction(make_op("add"), l_mul, unsq_off_val);
    }

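    // Tile: implemented by repeatedly concatenating the input with itself,
    // repeats[i] copies along dimension i.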
    instruction_ref
    parse_tile(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument arg_s = args[1]->eval();
        check_arg_empty(arg_s, "PARSE_TILE: dynamic shape is not supported");
        std::vector<std::int64_t> repeats;
        arg_s.visit([&](auto input) { repeats.assign(input.begin(), input.end()); });

        auto l0 = args[0];
        for(int i = 0; i < repeats.size(); i++)
        {
            auto l1 = l0;
            for(int j = 1; j < repeats[i]; j++)
            {
                l0 = mm->add_instruction(op::concat{i}, l0, l1);
            }
        }
        return l0;
    }

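    // Range: start, limit, and delta must all be constant scalars; the
    // sequence is evaluated at parse time and emitted as a literal.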
    instruction_ref
    parse_range(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {

        auto start_arg = args[0]->eval();
        check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported");
        auto limit_arg = args[1]->eval();
        check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported");
        auto delta_arg = args[2]->eval();
        check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported");

        assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and
               args[2]->get_shape().elements() == 1);

        instruction_ref l0;

        visit_all(start_arg, limit_arg, delta_arg)([&](auto start, auto limit, auto delta) {
            auto start_val = start.front();
            auto limit_val = limit.front();
            auto delta_val = delta.front();

            size_t num_elements = static_cast<size_t>(
                ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)));

            assert(num_elements > 0);

            using type = decltype(start_val);

            std::vector<type> range_vals(num_elements);

            std::generate(range_vals.begin(), range_vals.end(), [&]() {
                auto result = start_val;
                start_val += delta_val;
                return result;
            });

            l0 = mm->add_literal({shape{args[0]->get_shape().type(), {num_elements}}, range_vals});
        });
        return l0;
    }

    enum class reduce_mode_t
    {
        sum  = 0,
        mean = 1,
        max  = 2
    };

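    // EmbeddingBag: gather the rows of the weight matrix selected by the
    // indices, then reduce them with the sum/mean/max mode. Only a single bag
    // is supported, i.e. the offsets input (args[2]) must have one element.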
    instruction_ref parse_embedding_bag(const node_info& info,
                                        std::vector<instruction_ref> args) const
    {
        if(args[2]->get_shape().elements() != 1)
            MIGRAPHX_THROW("PARSE_EMBEDDING_BAG: MIGraphX only supports offsets of size 1");
        reduce_mode_t reduce_mode = reduce_mode_t::sum;
        if(contains(info.attributes, "mode"))
        {
            reduce_mode = static_cast<reduce_mode_t>(info.attributes.at("mode").i());
        }

        auto l0 = mm->add_instruction(op::gather{}, args[0], args[1]);
        switch(reduce_mode)
        {
        case reduce_mode_t::sum:
            l0 = mm->add_instruction(make_op("reduce_sum", {{"axes", {0}}}), l0);
            break;
        case reduce_mode_t::mean:
            l0 = mm->add_instruction(make_op("reduce_mean", {{"axes", {0}}}), l0);
            break;
        case reduce_mode_t::max:
            l0 = mm->add_instruction(make_op("reduce_max", {{"axes", {0}}}), l0);
            break;
        }
        return l0;
    }

    instruction_ref
    parse_aten(const std::string&, const node_info& info, std::vector<instruction_ref> args) const
    {
        if(contains(info.attributes, "operator"))
        {
            auto op_name = info.attributes.at("operator").s();
            if(op_name.find("embedding_bag") != std::string::npos)
            {
                return parse_embedding_bag(info, std::move(args));
            }
        }
        MIGRAPHX_THROW("PARSE_ATEN: unsupported custom operator");
    }

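    // Dropout is an identity operation at inference time: return the input
    // unchanged along with an all-ones mask.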
    std::vector<instruction_ref>
    parse_dropout(const std::string&, const node_info&, std::vector<instruction_ref> args) const
    {
        auto out = mm->add_instruction(make_op("identity"), args[0]);
        auto s   = args[0]->get_shape();
        std::vector<int8_t> vec(s.elements(), 1);
        shape mask_s{shape::bool_type, s.lens()};
        auto mask = mm->add_literal(literal(mask_s, vec));

        return {out, mask};
    }

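    // NonZero helpers: the input must be a parse-time constant; the indices of
    // its non-zero elements are computed here and returned as an int64 literal
    // of shape {rank, number_of_nonzeros}.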
    template <class T>
    std::vector<std::size_t> nonzero_indices(const std::vector<T>& data)
    {
        std::vector<std::size_t> indices;
        for(std::size_t i = 0; i < data.size(); ++i)
        {
            if(!float_equal(data[i], 0))
                indices.push_back(i);
        }

        return indices;
    }

    instruction_ref
    parse_nonzero(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        migraphx::argument data_arg = args.back()->eval();
        check_arg_empty(data_arg, "PARSE_NONZERO: cannot support non-constant input!");

        std::vector<std::size_t> indices;
        data_arg.visit([&](auto val) {
            using val_type = std::remove_cv_t<typename decltype(val)::value_type>;
            std::vector<val_type> vec_data;
            vec_data.assign(val.begin(), val.end());
            indices = this->nonzero_indices(vec_data);
        });

        shape in_s = args[0]->get_shape();
        shape out_s{shape::int64_type, {in_s.lens().size(), indices.size()}};

        std::vector<int64_t> out_data(out_s.elements());
        for(std::size_t i = 0; i < indices.size(); ++i)
        {
            auto idx = in_s.multi(indices[i]);
            for(std::size_t j = 0; j < in_s.lens().size(); ++j)
            {
                out_data[out_s.index({j, i})] = idx[j];
            }
        }

        return mm->add_literal(literal(out_s, out_data));
    }

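    // Comparison operators (Equal, Greater, Less, ...): apply the broadcast
    // binary op, then convert the result to bool if it is not already.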
    instruction_ref parse_compare_op(const std::string&,
                                     const std::string& op_name,
                                     const node_info&,
                                     std::vector<instruction_ref> args)
    {
        auto l = add_broadcastable_binary_op(args[0], args[1], op_name);
        if(l->get_shape().type() != shape::bool_type)
        {
            l = mm->add_instruction(make_op("convert", {{"target_type", shape::bool_type}}), l);
        }
        return l;
    }

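    // Upsample (nearest mode only): precomputes, for every output coordinate,
    // the linear index of the nearest input element, then implements the
    // resize as a reshape-to-1D followed by a gather with those indices.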
    instruction_ref
    parse_upsample(const std::string&, const node_info& info, std::vector<instruction_ref> args)
    {
        if(contains(info.attributes, "mode"))
        {
            auto mode = info.attributes.at("mode").s();
            if(mode != "nearest")
            {
                MIGRAPHX_THROW("PARSE_UPSAMPLE: only nearest mode is supported!");
            }
        }

        auto arg_scale = args[1]->eval();
        check_arg_empty(arg_scale, "PARSE_UPSAMPLE: only constant scale is supported!");
        std::vector<float> vec_scale;
        arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });

        auto in_s    = args[0]->get_shape();
        auto in_lens = in_s.lens();
        if(in_lens.size() != vec_scale.size())
        {
            MIGRAPHX_THROW("PARSE_UPSAMPLE: ranks of input and scale are different!");
        }

        std::vector<std::size_t> out_lens(in_lens.size());
        std::transform(in_lens.begin(),
                       in_lens.end(),
                       vec_scale.begin(),
                       out_lens.begin(),
                       [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });

        std::vector<float> idx_scale(in_lens.size());
        std::transform(
            out_lens.begin(),
            out_lens.end(),
            in_lens.begin(),
            idx_scale.begin(),
            [](auto od, auto id) { return (od == id) ? 1.0f : (id - 1.0f) / (od - 1.0f); });

        shape out_s{in_s.type(), out_lens};
        std::vector<int> ind(out_s.elements());

        // map out_idx to in_idx
        shape_for_each(out_s, [&](auto idx) {
            auto in_idx = idx;
            std::transform(idx.begin(),
                           idx.end(),
                           idx_scale.begin(),
                           in_idx.begin(),
                           // nearest mode
                           [](auto index, auto scale) {
                               return static_cast<std::size_t>(std::round(index * scale));
                           });

            ind[out_s.index(idx)] = static_cast<int64_t>(in_s.index(in_idx));
        });

        // reshape input to one-dimension
        std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
        shape ind_s{shape::int32_type, out_lens};
        auto rsp     = mm->add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
        auto ins_ind = mm->add_literal(literal(ind_s, ind));
        return mm->add_instruction(make_op("gather", {{"axis", 0}}), rsp, ins_ind);
    }

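    // Where: select elementwise from x (args[1]) where cond is true,
    // otherwise from y (args[2]).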
    instruction_ref
    parse_where(const std::string&, const node_info&, std::vector<instruction_ref> args)
    {
        auto type = args[1]->get_shape().type();
        // the operation of if cond == 1 select x; else select y,
        // is equivalent to cond * (x - y) + y
        auto cond = mm->add_instruction(make_op("convert", {{"target_type", type}}), args[0]);
        auto diff = add_broadcastable_binary_op(args[1], args[2], "sub");
        auto cd   = add_broadcastable_binary_op(diff, cond, "mul");
        return add_broadcastable_binary_op(cd, args[2], "add");
    }

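    // Parse a model from a stream; "name" (when given) is the file path, used
    // to resolve externally stored tensor data relative to the model's
    // directory.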
    void parse_from(std::istream& is, std::string name = "")
    {
        this->filename   = std::move(name);
        auto parent_path = fs::path(this->filename).parent_path();
        if(not parent_path.empty())
            this->path = parent_path;

        onnx::ModelProto model;
        if(model.ParseFromIstream(&is))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

    void parse_from(const void* data, std::size_t size)
    {
        onnx::ModelProto model;
        if(model.ParseFromArray(data, size))
        {
            if(model.has_graph())
            {
                this->parse_graph(model.graph());
            }
        }
        else
        {
            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }

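    // Build the program from a GraphProto: initializers become literals,
    // remaining graph inputs become parameters, nodes are converted in order
    // through the registered op parsers, and the graph outputs become the
    // return values.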
    void parse_graph(const onnx::GraphProto& graph)
    {
        for(auto&& f : graph.initializer())
        {
            instructions[f.name()] = mm->add_literal(parse_tensor(f));
        }

        for(auto&& input : graph.input())
        {
            const std::string& name = input.name();
            // input not in initializer_data, so it is a real input
            if(!contains(instructions, name))
            {
                std::vector<std::size_t> dims;
                if(map_input_dims.count(name) > 0)
                {
                    dims = map_input_dims.at(name);
                }

                shape s            = parse_type(input.type(), dims);
                instructions[name] = mm->add_parameter(name, s);
            }
        }

        for(auto&& node : graph.node())
        {
            std::vector<instruction_ref> args;
            for(auto&& input : node.input())
            {
                if(input.empty())
                {
                    this->parse_undefined(input);
                }
                if(instructions.count(input) == 0)
                {
                    MIGRAPHX_THROW("PARSE_GRAPH: invalid onnx file. Input \"" + input +
                                   "\" is unavailable due to unordered nodes!");
                }
                args.push_back(instructions.at(input));
            }

            std::vector<instruction_ref> result;
            std::size_t output_num = static_cast<std::size_t>(node.output().size());
            if(ops.count(node.op_type()) == 0)
            {
                if(skip_unknown_operators)
                    result.push_back(mm->add_instruction(op::unknown{node.op_type()}, args));
                else
                    MIGRAPHX_THROW("Unknown operator: " + node.op_type());
            }
            else
            {
                result = ops[node.op_type()]({get_attributes(node), output_num}, args);
            }

            output_num = std::min<std::size_t>(output_num, result.size());
            std::transform(node.output().begin(),
                           node.output().begin() + output_num,
                           result.begin(),
                           std::inserter(instructions, instructions.end()),
                           [](auto&& x, auto&& y) { return std::make_pair(x, y); });
        }

        // Find instructions corresponding to the output
        auto prog_output = graph.output();
        std::vector<std::string> all_output_names;
        std::vector<std::string> prog_output_names;
        std::transform(prog_output.begin(),
                       prog_output.end(),
                       std::back_inserter(all_output_names),
                       [](auto& node) { return node.name(); });
        std::copy_if(
            all_output_names.begin(),
            all_output_names.end(),
            std::back_inserter(prog_output_names),
            [&](const auto& name) { return !(name.empty() or instructions.count(name) == 0); });

        std::vector<instruction_ref> output_ins;
        std::transform(prog_output_names.begin(),
                       prog_output_names.end(),
                       std::back_inserter(output_ins),
                       [&](const auto& name) { return instructions[name]; });

        // add the return instruction
        mm->add_return(output_ins);
    }

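    // An empty input name denotes an omitted optional input; map it to an
    // undefined instruction.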
    void parse_undefined(const std::string& name)
    {
        if(!contains(instructions, name))
        {
            auto ins           = mm->add_instruction(op::undefined{});
            instructions[name] = ins;
        }
    }

    static attribute_map get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
        for(auto&& attr : node.attribute())
        {
            result[attr.name()] = attr;
        }
        return result;
    }

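    // Map an ONNX TensorProto data type code to the migraphx shape type;
    // string and complex types are rejected.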
    static shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 9: return shape::bool_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }

    template <class T>
    static literal from_repeated(shape::type_t t, const T& r)
    {
        std::size_t size = r.size();
        return literal{{t, {size}}, r.begin(), r.end()};
    }

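    // Convert an attribute to a literal; attribute types with no literal
    // representation (strings, graphs, etc.) yield an empty literal.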
    literal parse_value(const onnx::AttributeProto& attr) const
    {
        switch(attr.type())
        {
        case onnx::AttributeProto::FLOAT: return literal{attr.f()};
        case onnx::AttributeProto::INT: return literal{attr.i()};
        case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
        case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
        case onnx::AttributeProto::UNDEFINED:
        case onnx::AttributeProto::GRAPH:
        case onnx::AttributeProto::STRING:
        case onnx::AttributeProto::STRINGS:
        case onnx::AttributeProto::TENSORS:
        case onnx::AttributeProto::SPARSE_TENSOR:
        case onnx::AttributeProto::SPARSE_TENSORS:
        case onnx::AttributeProto::GRAPHS: return {};
        }
        MIGRAPHX_THROW("PARSE_VALUE: Invalid attribute type " + std::to_string(attr.type()));
    }

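    // Convert a TensorProto to a literal. Externally stored data is read from
    // disk first, then raw_data, then the typed repeated fields.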
    literal parse_tensor(const onnx::TensorProto& t) const
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(not t.external_data().empty())
        {
            const std::string& data_file = t.external_data().at(0).value();
            auto raw_buffer              = read_buffer(path + "/" + data_file);
            std::string s(raw_buffer.begin(), raw_buffer.end());
            auto type = get_type(t.data_type());
            return create_literal(type, dims, s.data());
        }
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
            auto type            = get_type(t.data_type());
            return create_literal(type, dims, s.data());
        }

        switch(t.data_type())
        {
        case onnx::TensorProto::BOOL: return create_literal(shape::bool_type, dims, t.int32_data());
        case onnx::TensorProto::INT8: return create_literal(shape::int8_type, dims, t.int32_data());
        case onnx::TensorProto::UINT8:
            return create_literal(shape::uint8_type, dims, t.int32_data());
        case onnx::TensorProto::INT16:
            return create_literal(shape::int16_type, dims, t.int32_data());
        case onnx::TensorProto::UINT16:
            return create_literal(shape::uint16_type, dims, t.int32_data());
        case onnx::TensorProto::INT32:
            return create_literal(shape::int32_type, dims, t.int32_data());
        case onnx::TensorProto::UINT32:
            return create_literal(shape::uint32_type, dims, t.uint64_data());
        case onnx::TensorProto::INT64:
            return create_literal(shape::int64_type, dims, t.int64_data());
        case onnx::TensorProto::UINT64:
            return create_literal(shape::uint64_type, dims, t.uint64_data());
        case onnx::TensorProto::FLOAT16:
        {
            std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case onnx::TensorProto::DOUBLE:
            return create_literal(shape::double_type, dims, t.double_data());
        case onnx::TensorProto::FLOAT:
            return create_literal(shape::float_type, dims, t.float_data());
        case onnx::TensorProto::UNDEFINED:
        case onnx::TensorProto::STRING:
        case onnx::TensorProto::COMPLEX64:
        case onnx::TensorProto::COMPLEX128:
            MIGRAPHX_THROW("PARSE_TENSOR: unsupported tensor type");
        }
        MIGRAPHX_THROW("PARSE_TENSOR: Invalid tensor type");
    }

    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
    {
        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
        if(dims.empty())
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }

    template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
    static literal create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
    {
        if(dims.empty())
            return literal{{shape_type}, data.begin(), data.end()};
        return literal{{shape_type, dims}, data.begin(), data.end()};
    }

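    // Build a migraphx shape from a TypeProto. User-supplied dims (from
    // onnx_options::map_input_dims) take precedence; otherwise dynamic or
    // invalid dimensions fall back to default_dim_value.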
    shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const
    {
        shape::type_t shape_type = get_type(t.tensor_type().elem_type());
        if(!input_dims.empty())
        {
            return {shape_type, input_dims};
        }

        std::vector<std::size_t> dims;
        auto&& tensor_dims = t.tensor_type().shape().dim();
        std::transform(tensor_dims.begin(),
                       tensor_dims.end(),
                       std::back_inserter(dims),
                       [&](auto&& d) -> std::size_t {
                           if(d.has_dim_value())
                           {
                               if(static_cast<int>(d.dim_value()) <= 0)
                               {
                                   return default_dim_value;
                               }
                               return d.dim_value();
                           }
                           else
                           {
                               return default_dim_value;
                           }
                       });

        if(dims.empty())
            return {shape_type};

        return {shape_type, dims};
    }

    void check_arg_empty(const argument& arg, const std::string& msg)
    {
        if(arg.empty())
        {
            MIGRAPHX_THROW(msg);
        }
    }
};

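// Shared driver for the parse_onnx* entry points below; forwards its
// arguments to onnx_parser::parse_from.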
template <class... Ts>
program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
    onnx_parser parser;
    parser.map_input_dims         = options.map_input_dims;
    parser.default_dim_value      = options.default_dim_value;
    parser.skip_unknown_operators = options.skip_unknown_operators;

    if(options.print_program_on_error)
    {
        // Log the program when it can't be parsed
        try
        {
            parser.parse_from(std::forward<Ts>(xs)...);
        }
        catch(...)
        {
            std::cerr << parser.prog << std::endl;
            throw;
        }
    }
    else
    {
        parser.parse_from(std::forward<Ts>(xs)...);
    }
    return std::move(parser.prog);
}

program parse_onnx(const std::string& name, const onnx_options& options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    return parse_onnx_from(options, input, name);
}

program parse_onnx_buffer(const std::string& buffer, const onnx_options& options)
{
    return parse_onnx_from(options, buffer.data(), buffer.size());
}

program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options)
{
    return parse_onnx_from(options, data, size);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx