"vscode:/vscode.git/clone" did not exist on "ec5a90ae263f8fa60389f313025dc74c6da2b232"
tf.cpp 55.4 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
20
21
#include <migraphx/make_op.hpp>

Khalique's avatar
Khalique committed
22
#include <migraphx/pad_calc.hpp>
Khalique's avatar
Khalique committed
23
24
25
26
27
28
29

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
Paul's avatar
Paul committed
30
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
kahmed10's avatar
kahmed10 committed
31
32
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
Khalique's avatar
Khalique committed
33

Khalique's avatar
Khalique committed
34
35
36
    node_map nodes;
    std::vector<tensorflow::NodeDef> input_nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
37
    program prog            = program();
38
    module* mm              = prog.get_main_module();
39
40
    bool is_nhwc            = true;
    unsigned int batch_size = 1;
Shucai Xiao's avatar
Shucai Xiao committed
41
42
    // Specified dims of inputs
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
Khalique's avatar
Khalique committed
43
44
45

    std::unordered_map<std::string, op_func> ops;

Paul's avatar
Paul committed
46
    // A layout transpose applies only in NHWC mode and only to rank-4 tensors.
    bool should_transpose(instruction_ref ins) const
    {
        const auto rank = ins->get_shape().lens().size();
        return is_nhwc and rank == 4;
    }

51
    // Insert a transpose back to NHWC layout when applicable; otherwise pass through.
    instruction_ref to_nhwc(instruction_ref ins) const
    {
        if(not should_transpose(ins))
            return ins;
        return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
    }

58
    // Insert a transpose to NCHW layout when applicable; otherwise pass through.
    instruction_ref to_nchw(instruction_ref ins) const
    {
        if(not should_transpose(ins))
            return ins;
        return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
    }

65
    // Reorder TF filter weights (H, W, C, K) into the (K, C, H, W) layout
    // expected by the convolution operators.
    instruction_ref to_kcxy(instruction_ref ins) const
    {
        return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
    }

70
    // Ensure the instruction's output has standard (contiguous) layout,
    // adding a contiguous op only when needed.
    instruction_ref make_contiguous(instruction_ref ins) const
    {
        return ins->get_shape().standard() ? ins : mm->add_instruction(make_op("contiguous"), ins);
    }

    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
Paul's avatar
Paul committed
81
        std::transform(
Paul's avatar
Paul committed
82
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
Paul's avatar
Paul committed
83
84
85
        return result;
    }

kahmed10's avatar
kahmed10 committed
86
87
88
89
90
91
92
93
    std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
        return result;
    }

Khalique's avatar
Khalique committed
94
    std::vector<size_t>
95
    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
96
    {
97
98
99
        auto attrs = attributes.at(s).list().i();
        std::vector<size_t> axes;
        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
Khalique's avatar
Khalique committed
100
        if(is_nhwc)
101
        {
Khalique's avatar
Khalique committed
102
            std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
Khalique's avatar
Khalique committed
103
                return parse_axis(axis, num_dims);
Khalique's avatar
Khalique committed
104
            });
105
106
107
108
        }
        return axes;
    }

Khalique's avatar
Khalique committed
109
    template <class T>
110
    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
Khalique's avatar
Khalique committed
111
112
113
    {
        if(is_nhwc)
        {
114
            std::vector<T> new_axes;
Khalique's avatar
Khalique committed
115
116
117
            std::transform(axes.begin(),
                           axes.end(),
                           std::back_inserter(new_axes),
Khalique's avatar
Khalique committed
118
                           [&](size_t axis) { return parse_axis(axis, num_dims); });
119
            return new_axes;
Khalique's avatar
Khalique committed
120
        }
121
        return axes;
Khalique's avatar
Khalique committed
122
123
    }

Khalique's avatar
Khalique committed
124
125
126
    // tf stores certain attributes such as strides, dilations, as a 4D input.
    // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
    // This helper function reorders the data to store for the respective operator member variables.
127
    template <class T>
128
    void reorder_data(std::vector<T>& prev_data) const
129
130
    {
        std::vector<T> new_data(prev_data.size());
131
        for(size_t i = 0; i < new_data.size(); i++)
132
        {
Khalique's avatar
Khalique committed
133
            auto new_idx         = parse_axis(i, new_data.size());
134
            new_data.at(new_idx) = prev_data.at(i);
135
        }
136
137
138
139
        prev_data = new_data;
    }

    template <class T>
140
    T parse_axis(const T& dim, const size_t num_dims) const
141
    {
Khalique's avatar
Khalique committed
142
        T new_dim = dim;
Khalique's avatar
Khalique committed
143
        if(is_nhwc and num_dims >= 4)
144
145
146
        {
            switch(dim)
            {
Khalique's avatar
Khalique committed
147
148
149
150
151
            case 0: new_dim = 0; break;
            case 1: new_dim = 2; break;
            case 2: new_dim = 3; break;
            case 3: new_dim = 1; break;
            default: break;
152
153
            }
        }
Khalique's avatar
Khalique committed
154
        return new_dim;
155
156
    }

157
158
159
160
161
162
163
    // Build the identity axis list [0, 1, ..., num_axes - 1].
    std::vector<int64_t> get_axes(size_t num_axes) const
    {
        std::vector<int64_t> axes(num_axes);
        std::iota(axes.begin(), axes.end(), int64_t{0});
        return axes;
    }

Khalique's avatar
Khalique committed
164
    // Expand a bitmask into a per-axis 0/1 flag vector; the LSB of the mask
    // corresponds to axis 0.
    std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
    {
        std::vector<int64_t> flags;
        flags.reserve(num_axes);
        for(size_t i = 0; i < num_axes; i++)
        {
            const bool set = ((mask >> i) & 1u) == 1;
            flags.push_back(set ? 1 : 0);
        }
        return flags;
    }

Khalique's avatar
Khalique committed
179
180
    // Register every supported TF op name with its handler.
    tf_parser()
    {
        // Ops that map directly onto a single MIGraphX op with no attributes.
        add_generic_op("All", make_op("identity"));
        add_generic_op("Identity", make_op("identity"));
        add_generic_op("LessEqual", make_op("identity"));
        add_generic_op("Relu", make_op("relu"));
        add_generic_op("Rsqrt", make_op("rsqrt"));
        add_generic_op("Tanh", make_op("tanh"));
        add_generic_op("StopGradient", make_op("identity"));

        // Elementwise binary ops that get broadcasting support.
        add_binary_op("Add", make_op("add"));
        add_binary_op("AddV2", make_op("add"));
        add_binary_op("Mul", make_op("mul"));
        add_binary_op("Pow", make_op("pow"));
        add_binary_op("SquaredDifference", make_op("sqdiff"));
        add_binary_op("Sub", make_op("sub"));

        // Ops needing custom attribute parsing; a trailing `false` disables the
        // automatic NHWC<->NCHW transpose wrapping in add_op.
        add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
        add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
        add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
        add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
        add_mem_op("Cast", &tf_parser::parse_cast, false);
        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
        add_mem_op("Const", &tf_parser::parse_constant);
        add_mem_op("Conv2D", &tf_parser::parse_conv);
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
        add_mem_op("FusedBatchNormV3", &tf_parser::parse_batchnorm);
        add_mem_op("GatherV2", &tf_parser::parse_gather, false);
        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
        add_mem_op("Mean", &tf_parser::parse_mean, false);
        add_mem_op("OneHot", &tf_parser::parse_onehot, false);
        add_mem_op("Pack", &tf_parser::parse_pack, false);
        add_mem_op("Pad", &tf_parser::parse_pad);
        add_mem_op("Relu6", &tf_parser::parse_relu6);
        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
        add_mem_op("Shape", &tf_parser::parse_shape, false);
        add_mem_op("Slice", &tf_parser::parse_slice, false);
        add_mem_op("Split", &tf_parser::parse_split, false);
        add_mem_op("SplitV", &tf_parser::parse_split, false);
        add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
        add_mem_op("Transpose", &tf_parser::parse_transpose, false);
    }

229
    template <class F>
kahmed10's avatar
kahmed10 committed
230
    void add_op(const std::string& name, F f, bool transpose = true)
231
    {
Paul's avatar
Paul committed
232
        if(transpose)
Paul's avatar
Paul committed
233
        {
kahmed10's avatar
kahmed10 committed
234
235
236
237
238
239
            ops.emplace(
                name,
                op_func{
                    [=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
                        return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
                    }});
Paul's avatar
Paul committed
240
241
242
        }
        else
        {
kahmed10's avatar
kahmed10 committed
243
244
245
246
247
            ops.emplace(name,
                        op_func{[=](const attribute_map& attributes,
                                    const std::vector<instruction_ref>& args) {
                            return std::vector<instruction_ref>{f(attributes, args)};
                        }});
Paul's avatar
Paul committed
248
        }
249
250
    }

Khalique's avatar
Khalique committed
251
    template <class F>
Paul's avatar
Paul committed
252
    void add_mem_op(std::string name, F f, bool transpose = true)
Khalique's avatar
Khalique committed
253
    {
Paul's avatar
Paul committed
254
255
256
257
258
        add_op(name,
               [=](auto&&... xs) {
                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
               },
               transpose);
Khalique's avatar
Khalique committed
259
260
261
262
263
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
Paul's avatar
Paul committed
264
265
266
267
268
269
270
271
272
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   if(args.size() != 2)
                       MIGRAPHX_THROW("binary operators should have 2 operands");
                   // TODO
                   // if(contains(attributes, "data_format"))
                   // {
                   //     if(is_nhwc)
                   //     {
273
                   //         l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
Paul's avatar
Paul committed
274
275
276
277
278
                   //     }
                   // }
                   return add_broadcastable_binary_op(args[0], args[1], x);
               },
               false);
Khalique's avatar
Khalique committed
279
280
281
282
283
    }

    // Add a binary op, inserting multibroadcast instructions when the two
    // argument shapes differ so the elementwise op sees matching dims.
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() == arg1->get_shape().lens())
        {
            // Same shape: no broadcasting required.
            return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
        }

        // Shapes differ: compute the broadcast output dims by right-aligning
        // the shorter shape and taking the elementwise max.
        //
        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        // Broadcast the (:,1,1) portion of s1 plus the 1st dimension of s1,
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // Broadcast the (:,:,1:,:) axis of s0 plus the 1st dimension of s1,
        // giving output_lens = (3,2,7,5)
        std::vector<size_t> small_lens = arg0->get_shape().lens();
        std::vector<size_t> large_lens = arg1->get_shape().lens();
        if(small_lens.size() > large_lens.size())
            std::swap(small_lens, large_lens);

        std::vector<size_t> output_lens(large_lens);
        const auto offset = large_lens.size() - small_lens.size();
        std::transform(small_lens.begin(),
                       small_lens.end(),
                       large_lens.begin() + offset,
                       output_lens.begin() + offset,
                       [](auto a, auto b) { return std::max(a, b); });

        auto b0 =
            mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}), arg0);
        auto b1 =
            mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}), arg1);
        return to_nhwc(mm->add_instruction(x, to_nchw(b0), to_nchw(b1)));
    }

    template <class T>
Paul's avatar
Paul committed
328
    void add_generic_op(std::string name, T x, bool transpose = true)
Khalique's avatar
Khalique committed
329
    {
Paul's avatar
Paul committed
330
331
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
332
                   return mm->add_instruction(x, args);
Paul's avatar
Paul committed
333
334
               },
               transpose);
Khalique's avatar
Khalique committed
335
336
    }

337
338
339
340
341
342
    // Handle ArgMax/ArgMin: the reduction axis is the second input (a constant
    // scalar). TF drops the reduced axis, so a squeeze follows the reduction.
    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        const int64_t axis = args[1]->eval().at<int64_t>();
        auto reduction     = mm->add_instruction(Op{axis}, args.front());
        return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), reduction);
    }

347
348
349
    // Handle FusedBatchNorm/FusedBatchNormV3. Only epsilon is read from the
    // graph; momentum and mode use fixed inference-time defaults.
    instruction_ref parse_batchnorm(const std::string&,
                                    attribute_map attributes,
                                    std::vector<instruction_ref> args) const
    {
        float epsilon                                     = 1e-5f;
        const float momentum                              = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(attributes, "epsilon"))
            epsilon = attributes.at("epsilon").f();
        return mm->add_instruction(op::batch_norm_inference{epsilon, momentum, bn_mode},
                                   std::move(args));
    }

362
    // Handle BiasAdd: broadcast the bias over the channel axis and add.
    instruction_ref
    parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        // assume output of previous layer is in NCHW (broadcast on channel)
        const uint64_t axis = 1;
        auto bias           = mm->add_instruction(
            make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
        return mm->add_instruction(make_op("add"), args[0], bias);
    }

371
372
373
    // Handle Cast: convert the input to the destination type named by the
    // "DstT" attribute.
    instruction_ref parse_cast(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        const shape::type_t dst_type = parse_type(attributes.at("DstT").type());
        return mm->add_instruction(make_op("convert", {{"target_type", dst_type}}),
                                   std::move(args));
    }

379
380
381
    // Handle ConcatV2: "N" is the number of tensors being concatenated, and
    // the concat axis is carried as the final (N-th) argument.
    instruction_ref parse_concat(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        // get index for axis within args
        size_t axis_idx = attributes.at("N").i();
        int64_t axis    = args[axis_idx]->eval().at<int64_t>();
        // return only first N arguments (assuming last index is the axis value)
        std::vector<instruction_ref> inputs(args.begin(), args.begin() + args.size() - 1);
        return mm->add_instruction(op::concat{axis}, inputs);
    }

    // Handle Const: materialize the embedded tensor proto as a literal.
    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&) const
    {
        return mm->add_literal(parse_tensor(attributes.at("value").tensor()));
    }

400
401
402
    // Handle Conv2D: read strides/dilations (reordered from TF's 4D NHWC
    // attribute layout), transpose weights to KCHW, and resolve SAME/VALID/
    // EXPLICIT padding before emitting the convolution.
    instruction_ref parse_conv(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        op::convolution op;
        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride_vals;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride_vals));
            reorder_data(stride_vals);
            if(stride_vals.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            // After reorder_data the spatial entries sit in dims 2 and 3.
            op.stride[0] = stride_vals[2];
            op.stride[1] = stride_vals[3];
        }
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation_vals;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation_vals));
            reorder_data(dilation_vals);
            if(dilation_vals.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation_vals[2];
            op.dilation[1] = dilation_vals[3];
        }

        auto weights = to_kcxy(args[1]);
        auto input   = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = input->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    // Asymmetric SAME padding: emit an explicit pad op instead.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    input = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}),
                                                input);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
            else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
                std::vector<size_t> padding;
                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
                if(padding.size() != 4)
                {
                    MIGRAPHX_THROW("padding should have 4 values");
                }
                if(padding[0] != padding[2] || padding[1] != padding[3])
                {
                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
                }
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        return mm->add_instruction(op, {input, weights});
    }

Khalique's avatar
Khalique committed
481
482
    // Handle DepthwiseConv2dNative: a grouped convolution with one group per
    // input channel. TF weights are (H, W, in_channels, multiplier); after the
    // KCHW transpose they are reshaped so out_channels = in_channels *
    // multiplier and the per-group in_channels is 1.
    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args) const
    {
        op::convolution op;
        size_t num_channels = args[0]->get_shape().lens()[1];
        op.group            = num_channels;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride_vals;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride_vals));
            reorder_data(stride_vals);
            if(stride_vals.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride_vals[2];
            op.stride[1] = stride_vals[3];
        }

        auto weights = to_kcxy(args[1]);
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation_vals;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation_vals));
            reorder_data(dilation_vals);
            if(dilation_vals.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation_vals[2];
            op.dilation[1] = dilation_vals[3];
        }

        auto input = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();

            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = input->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    // Asymmetric SAME padding: emit an explicit pad op instead.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    input = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}),
                                                input);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
        }

        std::vector<int64_t> new_weights_shape;
        copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = new_weights_shape[0];
        int64_t out_channels = num_channels * multiplier;
        new_weights_shape[0] = out_channels;
        new_weights_shape[1] = 1;
        // Make sure weights are contiguous before doing reshape
        auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
                                               make_contiguous(weights));

        return mm->add_instruction(op, {input, new_weights});
    }

567
568
569
    // Handle ExpandDims: insert a size-1 dimension at the position given by
    // the second input; negative positions count from the end.
    instruction_ref parse_expanddims(const std::string&,
                                     const attribute_map&,
                                     std::vector<instruction_ref> args) const
    {
        std::vector<size_t> input_dims = args[0]->get_shape().lens();
        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
        const size_t num_dims = input_dims.size();
        const int32_t dim     = args[1]->eval().at<int32_t>();

        // e.g. dim == -1 appends a trailing 1.
        const auto insert_at = (dim < 0) ? (num_dims + dim + 1) : static_cast<size_t>(dim);
        new_dims.insert(new_dims.begin() + insert_at, 1);
        return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
    }

Khalique's avatar
Khalique committed
587
    // Handle GatherV2: the gather axis is supplied as the third input
    // (a constant scalar).
    instruction_ref
    parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const int axis = args[2]->eval().at<int32_t>();
        return mm->add_instruction(op::gather{axis}, {args[0], args[1]});
    }

595
596
597
    // Handle MatMul/BatchMatMul(V2): MatMul flags transposition with
    // transpose_a/transpose_b, BatchMatMul with adj_x/adj_y. A transpose of
    // the last two axes is inserted for each flagged operand before the dot.
    instruction_ref parse_matmul(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        bool transa = false;
        bool transb = false;

        if(contains(attributes, "transpose_a"))
            transa = attributes.at("transpose_a").b();
        if(contains(attributes, "transpose_b"))
            transb = attributes.at("transpose_b").b();
        if(contains(attributes, "adj_x"))
            transa = attributes.at("adj_x").b();
        if(contains(attributes, "adj_y"))
            transb = attributes.at("adj_y").b();

        // Identity permutation with the last two axes swapped.
        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        std::iter_swap(perm.end() - 1, perm.end() - 2);

        auto lhs = transa ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
                          : args[0];
        auto rhs = transb ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
                          : args[1];

        return mm->add_instruction(make_op("dot"), lhs, rhs);
    }

633
634
635
    // Mean: reduce over the axes supplied by the second argument; when
    // keep_dims is false the reduced axes are squeezed away afterwards.
    instruction_ref parse_mean(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        const bool keep_dims = attributes.at("keep_dims").b();
        auto reduce_axes     = args[1]->eval().get<int32_t>().to_vector<int64_t>();

        auto reduced =
            mm->add_instruction(make_op("reduce_mean", {{"axes", reduce_axes}}), args[0]);
        if(keep_dims)
            return reduced;
        return mm->add_instruction(make_op("squeeze", {{"axes", reduce_axes}}), reduced);
    }

651
652
653
    // OneHot is implemented as a gather from a depth x depth matrix filled
    // with off_value everywhere and on_value on the diagonal; the index
    // tensor (args[0]) then selects rows. Only axis == -1 is supported.
    //
    // args: [indices, depth, on_value, off_value]
    instruction_ref parse_onehot(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());

        int64_t axis    = -1;
        float on_value  = args[2]->eval().at<float>();
        float off_value = args[3]->eval().at<float>();

        std::vector<float> depth_input(depth * depth, off_value);
        // Fix: use an unsigned loop index; the previous `int` index was
        // compared against the size_t `depth` (sign-conversion hazard).
        for(size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = mm->add_literal({s, depth_input});
            // Gathering rows of the on/off matrix yields the one-hot vectors.
            return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
        }
        MIGRAPHX_THROW("MIGraphX does not support axis != -1");
    }

Khalique's avatar
Khalique committed
678
679
    // Pack (tf stack): unsqueeze every input along `axis`, then concat the
    // results along that same axis.
    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args) const
    {
        int64_t axis = 0;
        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();

        size_t input_size = args.front()->get_shape().lens().size();
        if(axis > input_size)
        {
            MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                           " must be smaller than input size " + to_string(input_size));
        }

        std::vector<instruction_ref> unsqueezed_args;
        unsqueezed_args.reserve(args.size());
        for(const auto& arg : args)
        {
            unsqueezed_args.push_back(
                mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg));
        }
        return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
    }

Khalique's avatar
Khalique committed
704
    // Pad: tf lays the paddings out as an (ndims, 2) tensor whose rows hold
    // the (before, after) pad amounts for each dimension; migraphx wants all
    // "before" pads followed by all "after" pads.
    instruction_ref
    parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const size_t ndims = args.front()->get_shape().lens().size();

        auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        for(size_t d = 0; d < ndims; ++d)
        {
            pad_per_dim[d].first  = tf_padding[2 * d];
            pad_per_dim[d].second = tf_padding[2 * d + 1];
        }
        // Re-order the per-dimension pads to match the layout conversion
        // applied to the data (NHWC -> NCHW).
        reorder_data(pad_per_dim);

        op::pad op;
        op.pads.resize(ndims * 2);
        for(size_t d = 0; d < ndims; ++d)
        {
            op.pads[d]         = pad_per_dim[d].first;
            op.pads[d + ndims] = pad_per_dim[d].second;
        }
        return mm->add_instruction(op, args.front());
    }

731
732
    // Translates tf MaxPool/AvgPool into a migraphx pooling op.
    // Data is already in NCHW at this point; the 4-element NHWC stride/ksize
    // attribute lists are reordered first, so indices [2],[3] are the spatial
    // dimensions.
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args) const
    {
        // The op name selects the mode: "MaxPool..." -> max, otherwise average.
        op::pooling op{starts_with(name, "Max") ? "max" : "average"};

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "ksize"))
        {
            std::vector<size_t> ksize;
            copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
            reorder_data(ksize);
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            op.lengths[0] = ksize[2];
            op.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                // Compute SAME padding for each spatial dim (H at index 0/2,
                // W at index 1/3 of `pads`).
                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
                calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);

                // Asymmetric padding cannot be expressed by op.padding, so an
                // explicit pad instruction is inserted in front instead.
                // NOTE(review): the pad value lowest() is right for max
                // pooling; confirm it is intended for "average" mode as well.
                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0                           = mm->add_instruction(
                        migraphx::make_op(
                            "pad",
                            {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
                        l0);
                }
                else
                {
                    // Symmetric padding is carried directly on the pooling op.
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
        }
        return mm->add_instruction(op, l0);
    }
Khalique's avatar
Khalique committed
791

kahmed10's avatar
kahmed10 committed
792
    // Relu6: clamp the input elementwise to the range [0, 6] using a clip
    // between two broadcast scalar literals.
    instruction_ref
    parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const auto out_lens = args[0]->get_shape().lens();

        // Broadcast a scalar literal to the full input shape.
        auto broadcast_scalar = [&](float v) {
            auto lit = mm->add_literal(v);
            return mm->add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}),
                                       lit);
        };

        auto lo = broadcast_scalar(0.0f);
        auto hi = broadcast_scalar(6.0f);
        return mm->add_instruction(make_op("clip"), args.front(), lo, hi);
    }

806
    // Reshape: the target shape arrives as a constant tensor in args[1].
    instruction_ref
    parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");

        op::reshape op;
        args[1]->eval().visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

817
818
819
    // The Shape op is constant-folded: migraphx represents its output as a
    // literal holding the input's dimensions as int32 values.
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        std::vector<std::size_t> lens = args[0]->get_shape().lens();
        std::vector<int32_t> dims_i32(lens.begin(), lens.end());
        migraphx::shape s(migraphx::shape::int32_type, {lens.size()});
        return mm->add_literal(migraphx::literal{s, dims_i32});
    }

830
    // Slice: tf provides (begin, size) per dimension; migraphx slice wants
    // (starts, ends, axes). A size of -1 means "through the end of that dim".
    instruction_ref
    parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        auto begins        = args[1]->eval().get<int32_t>().to_vector();
        auto sizes         = args[2]->eval().get<int32_t>().to_vector();
        auto input_lens    = args[0]->get_shape().lens();
        const size_t ndims = input_lens.size();

        op::slice op;
        op.starts.assign(begins.begin(), begins.end());
        op.axes.resize(ndims);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        op.ends.resize(ndims);
        for(size_t d = 0; d < ndims; ++d)
        {
            if(sizes[d] == -1)
                op.ends[d] = input_lens[d];
            else
                op.ends[d] = begins[d] + sizes[d];
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

Khalique's avatar
Khalique committed
853
854
855
856
857
    // Shared parser for softmax-style ops (templated so logsoftmax reuses it).
    template <class Op>
    instruction_ref parse_softmax(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        const auto rank = args[0]->get_shape().lens().size();
        // Default axis is -1, i.e. the innermost dimension.
        int axis = -1;
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int>(attributes.at("axis").i());
        }
        // Normalize a negative axis into [0, rank).
        if(axis < 0)
        {
            axis += rank;
        }

        return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
    }

kahmed10's avatar
kahmed10 committed
873
874
    // Split/SplitV: cut the input into multiple outputs along one axis.
    // Split takes (axis, input) with an even "num_split"; SplitV takes
    // (input, size_splits, axis), signalled here by a third argument.
    // Each output becomes a slice instruction over the shared input.
    std::vector<instruction_ref> parse_split(const std::string&,
                                             const attribute_map& attributes,
                                             std::vector<instruction_ref> args) const
    {
        // SplitV passes the per-output sizes as a second tensor argument.
        bool vector_as_input = args.size() == 3;
        int num_outputs      = 1;
        auto axis_arg        = args[0];
        auto input_arg       = args[1];
        if(vector_as_input)
        {
            // SplitV argument order: (input, size_splits, axis).
            input_arg = args[0];
            axis_arg  = args[2];
        }

        if(contains(attributes, "num_split"))
            num_outputs = attributes.at("num_split").i();

        // slice_pos accumulates the cut points, starting at offset 0.
        std::vector<int> splits(num_outputs);
        std::vector<int> slice_pos{0};
        if(vector_as_input)
        {
            // SplitV: the number of outputs is implied by the sizes vector.
            splits      = args[1]->eval().get<int32_t>().to_vector();
            num_outputs = splits.size();
        }

        assert(num_outputs > 0);

        // A single output degenerates to an identity pass-through.
        if(num_outputs == 1)
            return std::vector<instruction_ref>{
                mm->add_instruction(make_op("identity"), input_arg)};

        auto lens     = input_arg->get_shape().lens();
        auto num_dims = lens.size();
        int axis      = axis_arg->eval().at<int32_t>();

        // ensure split is made evenly if "num_split" is used
        assert(vector_as_input or lens[axis] % num_outputs == 0);

        auto split_size = lens[axis] / num_outputs;

        // push back first end point of slice
        if(vector_as_input)
        {
            slice_pos.push_back(splits[0]);
        }
        else
        {
            slice_pos.push_back(split_size);
        }

        // calculate remaining end points for each slice
        // (for SplitV the sizes are turned into a running prefix sum)
        for(auto i = 1; i < num_outputs; i++)
        {
            if(vector_as_input)
            {
                splits[i] += splits[i - 1];
                slice_pos.push_back(splits[i]);
            }
            else
            {
                slice_pos.push_back((i + 1) * split_size);
            }
        }
        // Emit one full-rank slice per output; only `axis` is restricted.
        std::vector<instruction_ref> result;
        for(auto i = 0; i < num_outputs; i++)
        {
            op::slice op;
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
            op.starts = std::vector<int64_t>(num_dims, 0);
            op.ends   = std::vector<int64_t>(lens.begin(), lens.end());

            op.starts[axis] = slice_pos[i];
            op.ends[axis]   = slice_pos[i + 1];
            result.push_back(mm->add_instruction(op, input_arg));
        }
        return result;
    }

Khalique's avatar
Khalique committed
952
953
    // Squeeze: drop the dimensions named in "squeeze_dims"; when none are
    // given, tf drops every dimension of extent 1.
    instruction_ref parse_squeeze(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args) const
    {
        const auto in_lens = args[0]->get_shape().lens();

        op::squeeze op;
        copy(attributes.at("squeeze_dims").list().i(), std::back_inserter(op.axes));

        if(op.axes.empty())
        {
            for(size_t axis = 0; axis < in_lens.size(); ++axis)
            {
                if(in_lens.at(axis) == 1)
                    op.axes.push_back(axis);
            }
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

Khalique's avatar
Khalique committed
974
975
976
    // StridedSlice (stride == 1 assumed by this parser): build a slice over
    // all axes, then honor begin_mask / end_mask (bits that reset a start to
    // 0 or an end to the full extent) and shrink_axis_mask (bits naming axes
    // to squeeze out of the result).
    instruction_ref parse_stridedslice(const std::string&,
                                       const attribute_map& attributes,
                                       std::vector<instruction_ref> args)
    {
        op::slice op;
        auto starts              = args[1]->eval().get<int32_t>().to_vector();
        auto ends                = args[2]->eval().get<int32_t>().to_vector();
        auto l0                  = args[0];
        size_t num_axes          = l0->get_shape().lens().size();
        std::vector<size_t> axes = l0->get_shape().lens();

        // Slice every axis; starts/ends come straight from the tf tensors.
        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        uint32_t begin_mask       = 0;
        uint32_t end_mask         = 0;
        uint32_t shrink_axis_mask = 0;
        uint32_t bitwise_compare  = 1;
        std::vector<int64_t> squeeze_axes;

        if(contains(attributes, "begin_mask"))
            begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());

        if(contains(attributes, "end_mask"))
            end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());

        if(contains(attributes, "shrink_axis_mask"))
            shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

        // Expand the bit masks into per-axis 0/1 flags.
        std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
        std::vector<int64_t> end_axes   = get_axes_from_mask(num_axes, end_mask);

        // Masked axes ignore the provided start/end and take the full range.
        for(size_t i = 0; i < num_axes; i++)
        {
            if(begin_axes.at(i) == 1)
            {
                op.starts.at(i) = 0;
            }
            if(end_axes.at(i) == 1)
            {
                op.ends.at(i) = axes.at(i);
            }
        }

        auto l1 = mm->add_instruction(op, l0);
        if(shrink_axis_mask == 0)
            return l1;

        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                squeeze_axes.push_back(i);
        }

        // Shrink the masked axes (each has extent 1 after the slice).
        return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
    }

1033
1034
1035
    // Transpose: the permutation arrives as a constant int32 tensor.
    instruction_ref parse_transpose(const std::string&,
                                    const attribute_map&,
                                    std::vector<instruction_ref> args) const
    {
        auto perm32 = args[1]->eval().get<int32_t>().to_vector();

        op::transpose op;
        op.dims.assign(perm32.begin(), perm32.end());
        return mm->add_instruction(op, args.front());
    }

Khalique's avatar
Khalique committed
1044
1045
1046
1047
1048
    // Build the program from a GraphDef: create a parameter for every
    // Placeholder node (with layout/dim fix-ups), then parse every node.
    void parse_graph(const tensorflow::GraphDef& graph)
    {
        nodes = get_nodes(graph, input_nodes);
        for(auto&& input : input_nodes)
        {
            const std::string& name   = input.name();
            attribute_map input_attrs = get_attributes(input);
            shape::type_t shape_type  = parse_type(input_attrs.at("dtype").type());
            std::vector<size_t> dims  = parse_dims(input_attrs.at("shape").shape());

            // User-specified input dims take precedence over the graph's own
            // shape attribute and skip the NHWC/batch fix-ups below.
            if(contains(map_input_dims, name))
            {
                dims = map_input_dims.at(name);
            }
            else
            {
                // 4-D NHWC inputs are reordered to NCHW.
                if(is_nhwc and dims.size() >= 4)
                {
                    reorder_data(dims);
                }
                // Unknown dims (stored as <= 0) are replaced by batch_size.
                std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
                    return static_cast<int>(dim) <= 0 ? batch_size : dim;
                });
            }

            shape s            = shape{shape_type, dims};
            instructions[name] = to_nhwc(mm->add_parameter(name, s));
        }
        // parse_node recurses through inputs, so iteration order is safe.
        for(auto&& p : nodes)
        {
            this->parse_node(p.first);
        }

        // Needs to add a ret instruction at the end of
        // the program
    }

    // Parse one node (memoized via the `instructions` map), recursively
    // parsing its producers first. Multi-output nodes register extra results
    // under "name:i" keys.
    void parse_node(const std::string& name)
    {
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            // assert ops ignored
            if(node.op() == "Assert" or contains(name, "Assert"))
                return;
            // noOps ignored
            if(node.op() == "NoOp" or contains(name, "NoOp"))
                return;
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
                // control dependencies (signified by ^ before the name) are ignored
                if(contains(input, "^"))
                    continue;
                if(nodes.count(input) > 0)
                {
                    std::string iname;
                    // input was from a node with multiple outputs
                    if(contains(input, ':'))
                    {
                        iname = input.substr(0, input.find(':'));
                    }
                    else
                    {
                        iname = get_name(nodes.at(input));
                    }
                    // Recurse so the producer's instruction exists before use.
                    assert(name != iname);
                    this->parse_node(iname);
                    args.push_back(instructions.at(input));
                }
                else
                {
                    // Producer not in the node map (e.g. an input parameter);
                    // its instruction must already be registered.
                    args.push_back(instructions.at(input));
                }
            }

            std::vector<instruction_ref> result;
            if(ops.count(node.op()) == 0)
            {
                // Unsupported ops are kept as placeholders so the graph stays
                // connected.
                result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
            }
            else
            {
                result = ops[node.op()](get_attributes(node), args);
            }

            assert(!result.empty());
            // First output has no ":" delimiter
            instructions[name] = result.front();
            for(size_t i = 1; i < result.size(); i++)
            {
                instructions[name + ":" + std::to_string(i)] = result.at(i);
            }
        }
    }

1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
    // Deserialize a binary GraphDef protobuf from the stream and build the
    // program; throws if the protobuf cannot be parsed.
    void parse_from(std::istream& is)
    {
        tensorflow::GraphDef graph;
        if(not graph.ParseFromIstream(&is))
        {
            throw std::runtime_error("Failed reading tf file");
        }
        this->parse_graph(graph);
    }

Khalique's avatar
Khalique committed
1154
1155
1156
    // Copy a node's protobuf attribute map into a std::unordered_map so the
    // parser callbacks can do cheap lookups.
    static attribute_map get_attributes(const tensorflow::NodeDef& node)
    {
        attribute_map result;
        result.reserve(node.attr().size());
        for(const auto& attr : node.attr())
        {
            result[attr.first] = attr.second;
        }
        return result;
    }

Khalique's avatar
Khalique committed
1164
    // Return the node's name as declared in the graph protobuf.
    static std::string get_name(const tensorflow::NodeDef& node)
    {
        return node.name();
    }
Khalique's avatar
Khalique committed
1165

Khalique's avatar
Khalique committed
1166
1167
    // Index all graph nodes by name and collect the Placeholder nodes (the
    // graph inputs) into `input_nodes`.
    static node_map get_nodes(const tensorflow::GraphDef& graph,
                              std::vector<tensorflow::NodeDef>& input_nodes)
    {
        node_map by_name;
        for(auto&& node : graph.node())
        {
            // assume each node in graph has an associated name
            const auto node_name = get_name(node);
            if(node_name.empty())
                MIGRAPHX_THROW("tf node with no name found");
            by_name[node_name] = node;
            if(node.op() == "Placeholder")
            {
                input_nodes.push_back(node);
            }
        }
        return by_name;
    }

    // Map a tensorflow DataType enum to the migraphx shape type.
    // Unsupported types fall through to the default-initialized type value.
    static shape::type_t parse_type(const tensorflow::DataType t)
    {
        shape::type_t shape_type{};
        switch(t)
        {
        // Directly supported numeric types.
        case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
        case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
        case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
        case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
        case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
        case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
        case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
        case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
        case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
        case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;

        // Types with no migraphx equivalent are left as the default.
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // tf pb should not use these types
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        // Protobuf enum range sentinels.
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        return shape_type;
    }

Khalique's avatar
Khalique committed
1245
    // Convert a TensorFlow TensorProto constant into a migraphx literal.
    //
    // TensorProto stores data either as packed raw bytes (`tensor_content`) or
    // as explicit repeated value fields (`float_val`, `int_val`, ...).  Both
    // encodings are handled below; unsupported dtypes raise a migraphx
    // exception instead of silently producing a bad literal.
    static literal parse_tensor(const tensorflow::TensorProto& t)
    {
        std::vector<size_t> dims = parse_dims(t.tensor_shape());
        // Total element count implied by the declared shape (1 for scalars).
        size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
        if(!t.tensor_content().empty()) // has raw data
        {
            const std::string& s = t.tensor_content();
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return literal{{shape::float_type, dims}, s.data()};
            // NOTE(review): DT_BOOL raw data is labeled int8 here but the
            // value-field path below maps DT_BOOL to int32 — confirm callers
            // tolerate the mismatch.
            case tensorflow::DataType::DT_BOOL:
            case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
            // NOTE(review): DT_UINT16 raw bytes are reinterpreted as int16
            // (bit pattern preserved, type label differs from the
            // value-field path which uses uint16).
            case tensorflow::DataType::DT_UINT16:
            case tensorflow::DataType::DT_INT16:
                return literal{{shape::int16_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT32:
                return literal{{shape::int32_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT64:
                return literal{{shape::int64_type, dims}, s.data()};
            case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, s.data()};
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                // Fix: previously `throw std::runtime_error("")` — an empty
                // message gives no diagnostic.  Use the project error macro
                // with a meaningful message instead.
                MIGRAPHX_THROW("Unsupported tensor data type in raw tensor_content");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        // No raw content: read from the explicit repeated value fields.
        switch(t.dtype())
        {
        case tensorflow::DataType::DT_FLOAT:
            return create_literal(
                shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
        case tensorflow::DataType::DT_INT8:
            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_UINT16:
            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT16:
            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT32:
            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT64:
            return create_literal(
                shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
        // NOTE(review): bools are widened to int32 here but to int8 in the
        // raw-content path above — confirm downstream consumers accept both.
        case tensorflow::DataType::DT_BOOL:
            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
        case tensorflow::DataType::DT_HALF:
        {
            // half_val stores the fp16 bit pattern widened into int32 fields;
            // narrow back to uint16 and reinterpret as fp16.
            std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
            std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           // NOTE(review): reinterpret_cast type-punning is
                           // technically UB under strict aliasing — consider
                           // memcpy into a half instead.
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        // NOTE(review): unlike the cases above, DT_DOUBLE bypasses
        // create_literal, so the dims<=1 scalar special-case does not apply.
        case tensorflow::DataType::DT_DOUBLE:
            return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_UINT32:
        case tensorflow::DataType::DT_UINT64:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
            // Fix: previously `throw std::runtime_error("")` with no message.
            MIGRAPHX_THROW("Unsupported tensor data type in value fields");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

1388
    template <class T>
Khalique's avatar
Khalique committed
1389
    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
1390
                                        const size_t& shape_size)
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
    {
        std::vector<T> data_vals(shape_size);
        // check if shape has enough data values given existing fields
        if(data.size() == 1)
        {
            std::fill(data_vals.begin(), data_vals.end(), data[0]);
        }
        else
            copy(data.begin(), data.end(), std::back_inserter(data_vals));
        return data_vals;
    }

Khalique's avatar
Khalique committed
1403
1404
1405
1406
    // Extract the dimension sizes of a TensorShapeProto as a plain vector.
    static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
    {
        const auto& proto_dims = s.dim();
        std::vector<size_t> dims;
        dims.reserve(proto_dims.size());
        for(const tensorflow::TensorShapeProto_Dim& d : proto_dims)
            dims.push_back(d.size());
        return dims;
    }
1413
1414

    template <class T>
Khalique's avatar
Khalique committed
1415
    static literal
1416
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
1417
    {
Khalique's avatar
Khalique committed
1418
        // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
1419
        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
1420
            return literal{{shape_type}, data};
1421
1422
        return literal{{shape_type, dims}, data};
    }
Khalique's avatar
Khalique committed
1423
1424
};

Shucai Xiao's avatar
Shucai Xiao committed
1425
// Parse a serialized TensorFlow GraphDef file into a migraphx program.
program parse_tf(const std::string& name, const tf_options& options)
{
    std::fstream file(name.c_str(), std::ios::in | std::ios::binary);
    tf_parser parser;
    // Carry the caller's options into the parser before reading the graph.
    parser.map_input_dims = options.map_input_dims;
    parser.batch_size     = options.batch_size;
    parser.is_nhwc        = options.is_nhwc;

#ifndef NDEBUG
    // Debug builds: dump whatever was built so far if parsing fails.
    try
    {
        parser.parse_from(file);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(file);
#endif
    // Convert the final instruction's layout back to NCHW before returning.
    parser.to_nchw(std::prev(parser.mm->end()));
    // prog is a member of a local, so NRVO cannot apply; move it out.
    return std::move(parser.prog);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx