// tf.cpp
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/batch_norm_inference.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/make_op.hpp>

#include <migraphx/pad_calc.hpp>
#include <migraphx/tune_axis.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

    node_map nodes;
    std::vector<tensorflow::NodeDef> input_nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    module* mm              = prog.get_main_module();
    bool is_nhwc            = true;
    unsigned int batch_size = 1;
    // Specified dims of inputs
    std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;

    std::unordered_map<std::string, op_func> ops;

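    // Layout helpers: TF tensors are typically NHWC while MIGraphX convolution and
    // pooling expect NCHW, so 4-D values are transposed on the way into and out of
    // each translated operator.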
    bool should_transpose(instruction_ref ins) const
    {
        return is_nhwc and ins->get_shape().lens().size() == 4;
    }

    instruction_ref to_nhwc(instruction_ref ins) const
    {
        if(should_transpose(ins))
            return mm->add_instruction(make_op("transpose", {{"dims", {0, 2, 3, 1}}}), ins);
        return ins;
    }

    instruction_ref to_nchw(instruction_ref ins) const
    {
        if(should_transpose(ins))
            return mm->add_instruction(make_op("transpose", {{"dims", {0, 3, 1, 2}}}), ins);
        return ins;
    }

    instruction_ref to_kcxy(instruction_ref ins) const
    {
        return mm->add_instruction(make_op("transpose", {{"dims", {3, 2, 0, 1}}}), ins);
    }

    instruction_ref make_contiguous(instruction_ref ins) const
    {
        if(ins->get_shape().standard())
            return ins;
        else
            return mm->add_instruction(make_op("contiguous"), ins);
    }

    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
        return result;
    }

    std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
        return result;
    }

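    // Axis helpers: when the graph is NHWC, axis and attribute indices are remapped
    // to their NCHW positions (see parse_axis below).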
    std::vector<size_t>
    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
    {
        auto attrs = attributes.at(s).list().i();
        std::vector<size_t> axes;
        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
        if(is_nhwc)
        {
            std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
                return parse_axis(axis, num_dims);
            });
        }
        return axes;
    }

    template <class T>
    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
    {
        if(is_nhwc)
        {
            std::vector<T> new_axes;
            std::transform(axes.begin(),
                           axes.end(),
                           std::back_inserter(new_axes),
                           [&](size_t axis) { return parse_axis(axis, num_dims); });
            return new_axes;
        }
        return axes;
    }

    // tf stores certain attributes such as strides and dilations as a 4D input.
    // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
    // This helper function reorders the data into the form stored by the respective
    // operator member variables.
    template <class T>
    void reorder_data(std::vector<T>& prev_data) const
    {
        std::vector<T> new_data(prev_data.size());
        for(size_t i = 0; i < new_data.size(); i++)
        {
            auto new_idx         = parse_axis(i, new_data.size());
            new_data.at(new_idx) = prev_data.at(i);
        }
        prev_data = new_data;
    }

    template <class T>
    T parse_axis(const T& dim, const size_t num_dims) const
    {
        T new_dim = dim;
        if(is_nhwc and num_dims >= 4)
        {
            switch(dim)
            {
            case 0: new_dim = 0; break;
            case 1: new_dim = 2; break;
            case 2: new_dim = 3; break;
            case 3: new_dim = 1; break;
            default: break;
            }
        }
        return new_dim;
    }

    std::vector<int64_t> get_axes(size_t num_axes) const
    {
        std::vector<int64_t> axes(num_axes);
        std::iota(axes.begin(), axes.end(), 0);
        return axes;
    }

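    // Expands a bitmask (e.g. begin_mask/end_mask from StridedSlice) into a per-axis
    // vector of 1s and 0s indexed by axis.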
    std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
    {
        uint32_t bitwise_compare = 1;
        std::vector<int64_t> axes;
        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to begin
            if(((mask >> i) & bitwise_compare) == 1)
                axes.push_back(1);
            else
                axes.push_back(0);
        }
        return axes;
    }

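    // The constructor builds the dispatch table mapping TF op names to parser
    // callbacks; the boolean passed to add_mem_op/add_op controls whether the
    // NHWC<->NCHW transpose wrapper is applied around the callback.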
    tf_parser()
    {
        add_generic_op("All", make_op("identity"));
        add_generic_op("Identity", make_op("identity"));
        add_generic_op("LessEqual", make_op("identity"));
        add_generic_op("Relu", make_op("relu"));
        add_generic_op("Rsqrt", make_op("rsqrt"));
        add_generic_op("Tanh", make_op("tanh"));
        add_generic_op("StopGradient", make_op("identity"));

        add_binary_op("Add", make_op("add"));
        add_binary_op("AddV2", make_op("add"));
        add_binary_op("Mul", make_op("mul"));
        add_binary_op("Pow", make_op("pow"));
        add_binary_op("SquaredDifference", make_op("sqdiff"));
        add_binary_op("Sub", make_op("sub"));

        add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
        add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
        add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
        add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
        add_mem_op("Cast", &tf_parser::parse_cast, false);
        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
        add_mem_op("Const", &tf_parser::parse_constant);
        add_mem_op("Conv2D", &tf_parser::parse_conv);
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
        add_mem_op("FusedBatchNormV3", &tf_parser::parse_batchnorm);
        add_mem_op("GatherV2", &tf_parser::parse_gather, false);
        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
        add_mem_op("Mean", &tf_parser::parse_mean, false);
        add_mem_op("OneHot", &tf_parser::parse_onehot, false);
        add_mem_op("Pack", &tf_parser::parse_pack, false);
        add_mem_op("Pad", &tf_parser::parse_pad);
        add_mem_op("Relu6", &tf_parser::parse_relu6);
        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
        add_mem_op("Shape", &tf_parser::parse_shape, false);
        add_mem_op("Slice", &tf_parser::parse_slice, false);
        add_mem_op("Split", &tf_parser::parse_split, false);
        add_mem_op("SplitV", &tf_parser::parse_split, false);
        add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
        add_mem_op("Transpose", &tf_parser::parse_transpose, false);
    }

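    // Registers a callback for a TF op. When transpose is true the arguments are
    // converted to NCHW before the callback runs and its result is converted back
    // to NHWC.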
    template <class F>
    void add_op(const std::string& name, F f, bool transpose = true)
    {
        if(transpose)
        {
            ops.emplace(
                name,
                op_func{
                    [=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
                        return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
                    }});
        }
        else
        {
            ops.emplace(name,
                        op_func{[=](const attribute_map& attributes,
                                    const std::vector<instruction_ref>& args) {
                            return std::vector<instruction_ref>{f(attributes, args)};
                        }});
        }
    }

    template <class F>
    void add_mem_op(std::string name, F f, bool transpose = true)
    {
        add_op(name,
               [=](auto&&... xs) {
                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
               },
               transpose);
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   if(args.size() != 2)
                       MIGRAPHX_THROW("binary operators should have 2 operands");
                   // TODO
                   // if(contains(attributes, "data_format"))
                   // {
                   //     if(is_nhwc)
                   //     {
                   //         l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
                   //     }
                   // }
                   return add_broadcastable_binary_op(args[0], args[1], x);
               },
               false);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Example:
            // s0 = (3,2,4,5) and s1 = (2,1,1)
            //
            // In this case we need to broadcast (:,1,1) portion of
            // s1 plus broadcast the 1st dimension of s1
            // giving output_lens = (3,2,4,5)
            //
            // Another example:
            // s0 = (3,2,1,5) and s1 = (2,7,5)
            // In this case we need to broadcast the (:,:,1:,:) axis
            // of s0 plus the 1st dimension of s1 giving
            // output_lens = (3,2,7,5)
            //
            // Get lengths for both arguments
            const std::vector<size_t>* s0 = &arg0->get_shape().lens();
            const std::vector<size_t>* s1 = &arg1->get_shape().lens();

            // Make sure s0 is the smaller size
            if(s0->size() > s1->size())
                std::swap(s0, s1);

            std::vector<size_t> output_lens(*s1);
            auto offset = s1->size() - s0->size();
            std::transform(s0->begin(),
                           s0->end(),
                           s1->begin() + offset,
                           output_lens.begin() + offset,
                           [](auto a, auto b) { return std::max(a, b); });

            auto l0 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
                                          arg0);
            auto l1 = mm->add_instruction(make_op("multibroadcast", {{"output_lens", output_lens}}),
                                          arg1);
            return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
        }
        else
        {
            return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x, bool transpose = true)
    {
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   return mm->add_instruction(x, args);
               },
               transpose);
    }

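    // ArgMax/ArgMin: the reduction axis comes from the second input; the reduced
    // dimension is then squeezed away.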
    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        axis         = args[1]->eval().at<int64_t>();
        auto ins     = mm->add_instruction(Op{axis}, args.front());
        return mm->add_instruction(make_op("squeeze", {{"axes", {axis}}}), ins);
    }

    instruction_ref parse_batchnorm(const std::string&,
                                    attribute_map attributes,
                                    std::vector<instruction_ref> args) const
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(attributes, "epsilon"))
        {
            epsilon = attributes.at("epsilon").f();
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return mm->add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
        auto l0       = mm->add_instruction(
            make_op("broadcast", {{"axis", axis}, {"dims", args[0]->get_shape().lens()}}), args[1]);
        return mm->add_instruction(make_op("add"), args[0], l0);
    }

    instruction_ref parse_cast(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        shape::type_t type = parse_type(attributes.at("DstT").type());
        return mm->add_instruction(make_op("convert", {{"target_type", type}}), std::move(args));
    }

    instruction_ref parse_concat(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        // get index for axis within args
        size_t axis_idx = attributes.at("N").i();
        int64_t axis    = args[axis_idx]->eval().at<int64_t>();
        op::concat op{axis};
        // return only first N arguments (assuming last index is the axis value)
        return mm->add_instruction(
            op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&) const
    {
        literal v = parse_tensor(attributes.at("value").tensor());
        return mm->add_literal(v);
    }

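    // Conv2D: strides/dilations arrive in NHWC order and are reordered to NCHW.
    // SAME padding is computed from the input and weight shapes; asymmetric padding
    // is handled by inserting an explicit pad instruction before the convolution.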
    instruction_ref parse_conv(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        op::convolution op;
        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
            reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto weights = to_kcxy(args[1]);
        auto l0      = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
            else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
                std::vector<size_t> padding;
                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
                if(padding.size() != 4)
                {
                    MIGRAPHX_THROW("padding should have 4 values");
                }
                if(padding[0] != padding[2] || padding[1] != padding[3])
                {
                    MIGRAPHX_THROW("migraphx does not support asymmetric padding");
                }
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        return mm->add_instruction(op, {l0, weights});
    }

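    // DepthwiseConv2dNative is mapped to a grouped convolution with group equal to
    // the input channel count; the weights are reshaped so out_channels becomes
    // num_channels * multiplier with in_channels set to 1.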
    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args) const
    {
        op::convolution op;
        size_t num_channels = args[0]->get_shape().lens()[1];
        op.group            = num_channels;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }

        auto weights = to_kcxy(args[1]);
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
            reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();

            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = mm->add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
        }

        std::vector<int64_t> new_weights_shape;
        copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = new_weights_shape[0];
        int64_t out_channels = num_channels * multiplier;
        new_weights_shape[0] = out_channels;
        new_weights_shape[1] = 1;
        // Make sure weights are contiguous before doing reshape
        auto new_weights = mm->add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
                                               make_contiguous(weights));

        return mm->add_instruction(op, {l0, new_weights});
    }

    instruction_ref parse_expanddims(const std::string&,
                                     const attribute_map&,
                                     std::vector<instruction_ref> args) const
    {
        std::vector<size_t> input_dims = args[0]->get_shape().lens();
        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
        size_t num_dims = input_dims.size();
        int32_t dim     = args[1]->eval().at<int32_t>();

        if(dim < 0)
        {
            new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
        }
        else
        {
            new_dims.insert(new_dims.begin() + dim, 1);
        }
        return mm->add_instruction(make_op("reshape", {{"dims", new_dims}}), args[0]);
    }

    instruction_ref
    parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        int axis = args[2]->eval().at<int32_t>();
        op::gather op{axis};
        return mm->add_instruction(op, {args[0], args[1]});
    }

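    // MatMul/BatchMatMul: transpose_a/transpose_b (or adj_x/adj_y) are honoured by
    // swapping the last two dimensions of the corresponding argument before the dot.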
    instruction_ref parse_matmul(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        bool transa = false;
        bool transb = false;

        if(contains(attributes, "transpose_a"))
        {
            transa = attributes.at("transpose_a").b();
        }
        if(contains(attributes, "transpose_b"))
        {
            transb = attributes.at("transpose_b").b();
        }

        if(contains(attributes, "adj_x"))
        {
            transa = attributes.at("adj_x").b();
        }
        if(contains(attributes, "adj_y"))
        {
            transb = attributes.at("adj_y").b();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::iter_swap(perm.end() - 1, perm.end() - 2);

        auto l1 = (transa) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[0])
                           : args[0];
        auto l2 = (transb) ? mm->add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
                           : args[1];

        return mm->add_instruction(make_op("dot"), l1, l2);
    }

    instruction_ref parse_mean(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        bool keep_dims = attributes.at("keep_dims").b();
        auto axes      = args[1]->eval().get<int32_t>().to_vector<int64_t>();

        if(keep_dims)
        {
            return mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
        }
        else
        {
            auto ins = mm->add_instruction(make_op("reduce_mean", {{"axes", axes}}), args[0]);
            return mm->add_instruction(make_op("squeeze", {{"axes", axes}}), ins);
        }
    }

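    // OneHot (axis == -1 only): build a depth x depth literal filled with off_value,
    // set its diagonal to on_value, and gather rows using the index input.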
    instruction_ref parse_onehot(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());

        int64_t axis    = -1;
        float on_value  = args[2]->eval().at<float>();
        float off_value = args[3]->eval().at<float>();

        std::vector<float> depth_input(depth * depth, off_value);
        for(int i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = mm->add_literal({s, depth_input});
            return mm->add_instruction(make_op("gather", {{"axis", 0}}), {l0, args[0]});
        }
        MIGRAPHX_THROW("MIGraphX does not support axis != -1");
    }

    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args) const
    {
        // reinterpret as unsqueeze with concat
        std::vector<instruction_ref> unsqueezed_args;
        int64_t axis = 0;
        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        size_t input_size = args.front()->get_shape().lens().size();
        if(axis > input_size)
        {
            MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                           " must be smaller than input size " + to_string(input_size));
        }

        std::transform(
            args.begin(),
            args.end(),
            std::back_inserter(unsqueezed_args),
            [&](instruction_ref arg) {
                return mm->add_instruction(make_op("unsqueeze", {{"axes", {axis}}}), arg);
            });
        return to_nhwc(mm->add_instruction(make_op("concat", {{"axis", axis}}), unsqueezed_args));
    }

    instruction_ref
    parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        size_t ndims = args.front()->get_shape().lens().size();

        // in tf, the paddings are arranged as a 2d shape (ndims, 2),
        // the last dim contains the left padding and right padding respectively
        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        for(size_t i = 0; i < 2 * ndims; i += 2)
        {
            pad_per_dim[i / 2].first  = tf_padding[i];
            pad_per_dim[i / 2].second = tf_padding[i + 1];
        }
        reorder_data(pad_per_dim);

        op::pad op;
        std::vector<int64_t> pads(ndims * 2);
        for(size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;
            pads[i + ndims] = pad_per_dim[i].second;
        }
        op.pads = pads;
        return mm->add_instruction(op, args.front());
    }

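    // MaxPool/AvgPool: kernel size and strides are reordered from NHWC; SAME padding
    // that cannot be expressed symmetrically is added as an explicit pad instruction.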
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::pooling op{starts_with(name, "Max") ? "max" : "average"};

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "ksize"))
        {
            std::vector<size_t> ksize;
            copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
            reorder_data(ksize);
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            op.lengths[0] = ksize[2];
            op.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
                calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0                           = mm->add_instruction(
                        migraphx::make_op(
                            "pad",
                            {{"pads", padding}, {"value", std::numeric_limits<float>::lowest()}}),
                        l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
        }
        return mm->add_instruction(op, l0);
    }

    instruction_ref
    parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        auto input_lens = args[0]->get_shape().lens();
        auto min_val    = mm->add_literal(0.0f);
        auto max_val    = mm->add_literal(6.0f);

        min_val =
            mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), min_val);
        max_val =
            mm->add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}), max_val);
        return mm->add_instruction(make_op("clip"), args.front(), min_val, max_val);
    }

    instruction_ref
    parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        op::reshape op;
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
        auto s = args[1]->eval();
        s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    // Use a literal instruction to replace the shape since the output of
    // the shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int32_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
        std::transform(
            arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
        return mm->add_literal(migraphx::literal{s, vec_shape});
    }

    instruction_ref
    parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        op::slice op;
        auto starts     = args[1]->eval().get<int32_t>().to_vector();
        auto size       = args[2]->eval().get<int32_t>().to_vector();
        auto axes       = args[0]->get_shape().lens();
        size_t num_axes = axes.size();

        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(num_axes);
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        for(size_t i = 0; i < num_axes; i++)
        {
            if(size[i] == -1)
                op.ends[i] = axes[i];
            else
                op.ends[i] = starts[i] + size[i];
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    // template to facilitate the logsoftmax later
    template <class Op>
    instruction_ref parse_softmax(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        int axis      = -1;
        auto num_dims = args[0]->get_shape().lens().size();
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int>(attributes.at("axis").i());
        }

        axis = tune_axis(num_dims, axis, "tf_parse_softmax");

        return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
    }

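    // Split/SplitV: the axis comes from an input tensor. Split divides the axis into
    // equal parts given by the num_split attribute, while SplitV takes explicit split
    // sizes as an extra input; each part becomes a slice instruction.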
    std::vector<instruction_ref> parse_split(const std::string&,
                                             const attribute_map& attributes,
                                             std::vector<instruction_ref> args) const
    {
        bool vector_as_input = args.size() == 3;
        int num_outputs      = 1;
        auto axis_arg        = args[0];
        auto input_arg       = args[1];
        if(vector_as_input)
        {
            input_arg = args[0];
            axis_arg  = args[2];
        }

        if(contains(attributes, "num_split"))
            num_outputs = attributes.at("num_split").i();

        std::vector<int> splits(num_outputs);
        std::vector<int> slice_pos{0};
        if(vector_as_input)
        {
            splits      = args[1]->eval().get<int32_t>().to_vector();
            num_outputs = splits.size();
        }

        assert(num_outputs > 0);

        if(num_outputs == 1)
            return std::vector<instruction_ref>{
                mm->add_instruction(make_op("identity"), input_arg)};

        auto lens     = input_arg->get_shape().lens();
        auto num_dims = lens.size();
        int axis      = axis_arg->eval().at<int32_t>();

        // ensure split is made evenly if "num_split" is used
        assert(vector_as_input or lens[axis] % num_outputs == 0);

        auto split_size = lens[axis] / num_outputs;

        // push back first end point of slice
        if(vector_as_input)
        {
            slice_pos.push_back(splits[0]);
        }
        else
        {
            slice_pos.push_back(split_size);
        }

        // calculate remaining end points for each slice
        for(auto i = 1; i < num_outputs; i++)
        {
            if(vector_as_input)
            {
                splits[i] += splits[i - 1];
                slice_pos.push_back(splits[i]);
            }
            else
            {
                slice_pos.push_back((i + 1) * split_size);
            }
        }
        std::vector<instruction_ref> result;
        for(auto i = 0; i < num_outputs; i++)
        {
            op::slice op;
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
            op.starts = std::vector<int64_t>(num_dims, 0);
            op.ends   = std::vector<int64_t>(lens.begin(), lens.end());

            op.starts[axis] = slice_pos[i];
            op.ends[axis]   = slice_pos[i + 1];
            result.push_back(mm->add_instruction(op, input_arg));
        }
        return result;
    }

    instruction_ref parse_squeeze(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::squeeze op;
        auto input_dims = args[0]->get_shape().lens();
        auto axes       = attributes.at("squeeze_dims").list().i();
        copy(axes, std::back_inserter(op.axes));

        if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
        {
            for(size_t i = 0; i < input_dims.size(); i++)
            {
                if(input_dims.at(i) == 1)
                {
                    op.axes.push_back(i);
                }
            }
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

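    // StridedSlice: begin_mask/end_mask reset the corresponding starts/ends to the
    // full range, and shrink_axis_mask squeezes the selected axes after slicing.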
    instruction_ref parse_stridedslice(const std::string&,
                                       const attribute_map& attributes,
                                       std::vector<instruction_ref> args)
    {
        op::slice op;
        auto starts              = args[1]->eval().get<int32_t>().to_vector();
        auto ends                = args[2]->eval().get<int32_t>().to_vector();
        auto l0                  = args[0];
        size_t num_axes          = l0->get_shape().lens().size();
        std::vector<size_t> axes = l0->get_shape().lens();

        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        uint32_t begin_mask       = 0;
        uint32_t end_mask         = 0;
        uint32_t shrink_axis_mask = 0;
        uint32_t bitwise_compare  = 1;
        std::vector<int64_t> squeeze_axes;

        if(contains(attributes, "begin_mask"))
            begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());

        if(contains(attributes, "end_mask"))
            end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());

        if(contains(attributes, "shrink_axis_mask"))
            shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

        std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
        std::vector<int64_t> end_axes   = get_axes_from_mask(num_axes, end_mask);

        for(size_t i = 0; i < num_axes; i++)
        {
            if(begin_axes.at(i) == 1)
            {
                op.starts.at(i) = 0;
            }
            if(end_axes.at(i) == 1)
            {
                op.ends.at(i) = axes.at(i);
            }
        }

        auto l1 = mm->add_instruction(op, l0);
        if(shrink_axis_mask == 0)
            return l1;

        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                squeeze_axes.push_back(i);
        }

        return mm->add_instruction(make_op("squeeze", {{"axes", squeeze_axes}}), l1);
    }

    instruction_ref parse_transpose(const std::string&,
                                    const attribute_map&,
                                    std::vector<instruction_ref> args) const
    {
        auto perm = args[1]->eval().get<int32_t>().to_vector();
        op::transpose op;
        op.dims = std::vector<int64_t>(perm.begin(), perm.end());

        return mm->add_instruction(op, args.front());
    }

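    // Walks the GraphDef: Placeholder nodes become parameters (with user-supplied or
    // batch-size-substituted dims), then every node is translated via parse_node.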
    void parse_graph(const tensorflow::GraphDef& graph)
    {
        nodes = get_nodes(graph, input_nodes);
        for(auto&& input : input_nodes)
        {
            const std::string& name   = input.name();
            attribute_map input_attrs = get_attributes(input);
            shape::type_t shape_type  = parse_type(input_attrs.at("dtype").type());
            std::vector<size_t> dims  = parse_dims(input_attrs.at("shape").shape());

            if(contains(map_input_dims, name))
            {
                dims = map_input_dims.at(name);
            }
            else
            {
                if(is_nhwc and dims.size() >= 4)
                {
                    reorder_data(dims);
                }
                std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
                    return static_cast<int>(dim) <= 0 ? batch_size : dim;
                });
            }

            shape s            = shape{shape_type, dims};
            instructions[name] = to_nhwc(mm->add_parameter(name, s));
        }
        for(auto&& p : nodes)
        {
            this->parse_node(p.first);
        }

        // Needs to add a ret instruction at the end of
        // the program
    }

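    // Recursively translates a node after its inputs, skipping Assert/NoOp nodes and
    // control-dependency edges; unrecognized ops become op::unknown placeholders.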
    void parse_node(const std::string& name)
    {
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
Khalique's avatar
Khalique committed
1099
1100
1101
            // assert ops ignored
            if(node.op() == "Assert" or contains(name, "Assert"))
                return;
kahmed10's avatar
kahmed10 committed
1102
1103
1104
            // noOps ignored
            if(node.op() == "NoOp" or contains(name, "NoOp"))
                return;
Khalique's avatar
Khalique committed
1105
1106
1107
1108
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
Khalique's avatar
Khalique committed
1109
1110
1111
                // control dependencies (signified by ^ before the name) are ignored
                if(contains(input, "^"))
                    continue;
Khalique's avatar
Khalique committed
1112
1113
                if(nodes.count(input) > 0)
                {
kahmed10's avatar
kahmed10 committed
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
                    std::string iname;
                    // input was from a node with multiple outputs
                    if(contains(input, ':'))
                    {
                        iname = input.substr(0, input.find(':'));
                    }
                    else
                    {
                        iname = get_name(nodes.at(input));
                    }
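                    // parse the producing node first (depth-first) so that its
                    // instruction exists before it is used as an argument here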
                    assert(name != iname);
                    this->parse_node(iname);
                    args.push_back(instructions.at(input));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }

            std::vector<instruction_ref> result;
            if(ops.count(node.op()) == 0)
            {
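                // ops without a registered parser become op::unknown placeholders
                // instead of aborting the whole parse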
                result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
            }
            else
            {
                result = ops[node.op()](get_attributes(node), args);
            }

            assert(!result.empty());
            // First output has no ":" delimiter
            instructions[name] = result.front();
            for(size_t i = 1; i < result.size(); i++)
            {
                instructions[name + ":" + std::to_string(i)] = result.at(i);
            }
        }
    }

    void parse_from(std::istream& is)
    {
        tensorflow::GraphDef graph;
        if(graph.ParseFromIstream(&is))
        {
            this->parse_graph(graph);
        }
        else
        {
            throw std::runtime_error("Failed reading tf file");
        }
    }

    static attribute_map get_attributes(const tensorflow::NodeDef& node)
    {
        attribute_map result;
        for(auto&& attr : node.attr())
        {
            result[attr.first] = attr.second;
        }
        return result;
    }

    static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }

    static node_map get_nodes(const tensorflow::GraphDef& graph,
                              std::vector<tensorflow::NodeDef>& input_nodes)
    {
        node_map result;
        for(auto&& node : graph.node())
        {
            auto node_name = get_name(node);
            // assume each node in graph has an associated name
            if(node_name.empty())
                MIGRAPHX_THROW("tf node with no name found");
            result[node_name] = node;
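            // Placeholder nodes act as graph inputs; collect them separately so
            // parse_graph can turn them into parameters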
            if(node.op() == "Placeholder")
            {
                input_nodes.push_back(node);
            }
        }
        return result;
    }

    static shape::type_t parse_type(const tensorflow::DataType t)
    {
        shape::type_t shape_type{};
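        // types without a MIGraphX equivalent fall through the switch below and
        // leave shape_type default-initialized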
        switch(t)
        {
        case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
        case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
        case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
        case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
        case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
        case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
        case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
        case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
        case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
        case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;

        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // tf pb should not use these types
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        return shape_type;
    }

    static literal parse_tensor(const tensorflow::TensorProto& t)
    {
        std::vector<size_t> dims = parse_dims(t.tensor_shape());
        size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
        if(!t.tensor_content().empty()) // has raw data
        {
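            // tensor_content holds the tensor's raw bytes; they are wrapped directly
            // in a literal with the mapped element type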
            const std::string& s = t.tensor_content();
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return literal{{shape::float_type, dims}, s.data()};
            case tensorflow::DataType::DT_BOOL:
            case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
            case tensorflow::DataType::DT_UINT16:
            case tensorflow::DataType::DT_INT16:
                return literal{{shape::int16_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT32:
                return literal{{shape::int32_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT64:
                return literal{{shape::int64_type, dims}, s.data()};
            case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, s.data()};
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                throw std::runtime_error("unsupported tensor data type");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.dtype())
        {
        case tensorflow::DataType::DT_FLOAT:
            return create_literal(
                shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
        case tensorflow::DataType::DT_INT8:
            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_UINT16:
            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT16:
            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT32:
            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT64:
            return create_literal(
                shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
        case tensorflow::DataType::DT_BOOL:
            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
        case tensorflow::DataType::DT_HALF:
        {
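            // half values arrive widened to int32 in half_val(); narrow them to
            // uint16 and reinterpret the bits as migraphx::half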
            std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
            std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case tensorflow::DataType::DT_DOUBLE:
            return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_UINT32:
        case tensorflow::DataType::DT_UINT64:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
            throw std::runtime_error("unsupported tensor data type");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

    template <class T>
    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
                                        const size_t& shape_size)
    {
        std::vector<T> data_vals(shape_size);
        // a single value is broadcast across the whole shape; otherwise copy the
        // provided values into the pre-sized buffer (back-inserting would append
        // past the shape_size default-initialized elements)
        if(data.size() == 1)
        {
            std::fill(data_vals.begin(), data_vals.end(), data[0]);
        }
        else
            std::copy(data.begin(), data.end(), data_vals.begin());
        return data_vals;
    }

    static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
    {
        std::vector<size_t> dims;
        auto input_dims = s.dim();
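        // unknown dimensions come through as non-positive sizes; parse_graph later
        // substitutes the batch size for them on graph inputs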
        std::transform(input_dims.begin(),
                       input_dims.end(),
                       std::back_inserter(dims),
                       [](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
        return dims;
    }

    template <class T>
    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
    {
        // if the protobuf gives explicit values and the shape is empty or a single
        // element, treat the literal as a scalar
        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }
};

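// Example (sketch) of driving this parser through the parse_tf entry point below.
// The tf_options field names are taken from their use in this file; "model.pb" is
// only a placeholder path:
//
//   migraphx::tf_options options;
//   options.is_nhwc    = true;
//   options.batch_size = 1;
//   migraphx::program p = migraphx::parse_tf("model.pb", options);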
program parse_tf(const std::string& name, const tf_options& options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    tf_parser parser;
    parser.is_nhwc        = options.is_nhwc;
    parser.batch_size     = options.batch_size;
    parser.map_input_dims = options.map_input_dims;

#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
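    // if the program is NHWC and its final instruction is 4-dimensional, append a
    // transpose so the result is returned in NCHW layout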
    parser.to_nchw(std::prev(parser.mm->end()));
    return std::move(parser.prog);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx