tf.cpp 46 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
Khalique's avatar
Khalique committed
20
#include <migraphx/pad_calc.hpp>
Khalique's avatar
Khalique committed
21
22
23
24
25
26
27

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
Khalique's avatar
Khalique committed
28
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
Khalique's avatar
Khalique committed
29
30
    // using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
Khalique's avatar
Khalique committed
31

Khalique's avatar
Khalique committed
32
33
34
35
36
37
38
39
    node_map nodes;
    std::vector<tensorflow::NodeDef> input_nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog = program();
    bool is_nhwc = true;

    std::unordered_map<std::string, op_func> ops;

Khalique's avatar
Khalique committed
40
41
    std::vector<size_t>
    parse_axes(const attribute_map& attributes, const std::string& s, const size_t& num_dims) const
42
    {
43
44
45
        auto attrs = attributes.at(s).list().i();
        std::vector<size_t> axes;
        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
Khalique's avatar
Khalique committed
46
        if(is_nhwc)
47
        {
Khalique's avatar
Khalique committed
48
            std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
Khalique's avatar
Khalique committed
49
                return parse_axis(axis, num_dims);
Khalique's avatar
Khalique committed
50
            });
51
52
53
54
        }
        return axes;
    }

Khalique's avatar
Khalique committed
55
    template <class T>
Khalique's avatar
Khalique committed
56
    std::vector<T> parse_axes(std::vector<T> axes, const size_t& num_dims) const
Khalique's avatar
Khalique committed
57
58
59
    {
        if(is_nhwc)
        {
60
            std::vector<T> new_axes;
Khalique's avatar
Khalique committed
61
62
63
            std::transform(axes.begin(),
                           axes.end(),
                           std::back_inserter(new_axes),
Khalique's avatar
Khalique committed
64
                           [&](size_t axis) { return parse_axis(axis, num_dims); });
65
            return new_axes;
Khalique's avatar
Khalique committed
66
        }
67
        return axes;
Khalique's avatar
Khalique committed
68
69
    }

Khalique's avatar
Khalique committed
70
71
72
    // tf stores certain attributes such as strides, dilations, as a 4D input.
    // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
    // This helper function reorders the data to store for the respective operator member variables.
73
    template <class T>
74
    void reorder_data(std::vector<T>& prev_data) const
75
76
    {
        std::vector<T> new_data(prev_data.size());
77
        for(size_t i = 0; i < new_data.size(); i++)
78
        {
Khalique's avatar
Khalique committed
79
            auto new_idx         = parse_axis(i, new_data.size());
80
            new_data.at(new_idx) = prev_data.at(i);
81
        }
82
83
84
85
        prev_data = new_data;
    }

    template <class T>
Khalique's avatar
Khalique committed
86
    T parse_axis(const T& dim, const size_t& num_dims) const
87
    {
Khalique's avatar
Khalique committed
88
        T new_dim = dim;
Khalique's avatar
Khalique committed
89
        if(is_nhwc and num_dims >= 4)
90
91
92
        {
            switch(dim)
            {
Khalique's avatar
Khalique committed
93
94
95
96
97
            case 0: new_dim = 0; break;
            case 1: new_dim = 2; break;
            case 2: new_dim = 3; break;
            case 3: new_dim = 1; break;
            default: break;
98
99
            }
        }
Khalique's avatar
Khalique committed
100
        return new_dim;
101
102
    }

103
104
105
106
107
108
109
    // Returns the sequence 0, 1, ..., num_axes - 1.
    std::vector<int64_t> get_axes(size_t num_axes) const
    {
        std::vector<int64_t> sequence;
        sequence.reserve(num_axes);
        for(size_t i = 0; i < num_axes; ++i)
            sequence.push_back(static_cast<int64_t>(i));
        return sequence;
    }

Khalique's avatar
Khalique committed
110
111
112
113
    tf_parser()
    {
        add_generic_op("Identity", op::identity{});
        add_generic_op("Relu", op::relu{});
Khalique's avatar
Khalique committed
114
        add_generic_op("Relu6", op::clip{6.0, 0.0});
Khalique's avatar
Khalique committed
115

116
        add_binary_op("Add", op::add{});
Khalique's avatar
Khalique committed
117
        add_binary_op("Mul", op::mul{});
Khalique's avatar
Khalique committed
118
        add_binary_op("Sub", op::sub{});
Khalique's avatar
Khalique committed
119

120
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
121
122
        add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
        add_mem_op("ConcatV2", &tf_parser::parse_concat);
Khalique's avatar
Khalique committed
123
124
        add_mem_op("Const", &tf_parser::parse_constant);
        add_mem_op("Conv2D", &tf_parser::parse_conv);
Khalique's avatar
Khalique committed
125
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
Khalique's avatar
Khalique committed
126
        add_mem_op("ExpandDims", &tf_parser::parse_expanddims);
Khalique's avatar
Khalique committed
127
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
128
        add_mem_op("MatMul", &tf_parser::parse_matmul);
129
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
Khalique's avatar
Khalique committed
130
        add_mem_op("Mean", &tf_parser::parse_mean);
Khalique's avatar
Khalique committed
131
        add_mem_op("Pack", &tf_parser::parse_pack);
Khalique's avatar
Khalique committed
132
        add_mem_op("Pad", &tf_parser::parse_pad);
133
134
135
        add_mem_op("Reshape", &tf_parser::parse_reshape);
        add_mem_op("Softmax", &tf_parser::parse_softmax);
        add_mem_op("Squeeze", &tf_parser::parse_squeeze);
136
        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
Khalique's avatar
Khalique committed
137
138
    }

139
140
141
142
143
144
145
146
147
148
149
150
151
    // Stores handler `f` in `ops` under `name`. An existing entry with the same
    // name is left untouched (emplace does not overwrite).
    template <class F>
    void add_op(std::string name, F f)
    {
        ops.emplace(std::move(name), std::move(f));
    }

    // Multi output op
    // Registers handler `f` in `ops` under `name`; the stored entry is the same
    // as the one add_op would create.
    template <class F>
    void add_multi_op(std::string name, F f)
    {
        add_op(std::move(name), std::move(f));
    }

Khalique's avatar
Khalique committed
152
153
154
    template <class F>
    void add_mem_op(std::string name, F f)
    {
155
        add_op(name, [=](auto&&... xs) {
Khalique's avatar
Khalique committed
156
157
158
159
160
161
162
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
Paul's avatar
Paul committed
163
        add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
Khalique's avatar
Khalique committed
164
165
            if(args.size() != 2)
                MIGRAPHX_THROW("binary operators should have 2 operands");
166
167
168
169
170
            auto l0 = args[1];
            if(contains(attributes, "data_format"))
            {
                if(is_nhwc)
                {
Khalique's avatar
Khalique committed
171
                    l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
172
173
174
                }
            }
            return add_broadcastable_binary_op(args[0], l0, x);
Khalique's avatar
Khalique committed
175
176
177
178
179
180
        });
    }

    // Emits binary operator `x` on arg0 and arg1. When their lens differ, both
    // are multibroadcast to a common shape first.
    //
    // The common shape starts as the lens of the higher-rank operand; its
    // trailing dimensions are then replaced by the element-wise maximum of the
    // two operands' right-aligned lens. Examples:
    //   (3,2,4,5) with (2,1,1) -> (3,2,4,5)
    //   (3,2,1,5) with (2,7,5) -> (3,2,7,5)
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        const auto& lens0 = arg0->get_shape().lens();
        const auto& lens1 = arg1->get_shape().lens();
        if(lens0 == lens1)
            return prog.add_instruction(x, {arg0, arg1});

        // On equal rank the first operand is treated as the "shorter" one.
        const bool first_is_shorter = lens0.size() <= lens1.size();
        const auto& shorter         = first_is_shorter ? lens0 : lens1;
        const auto& longer          = first_is_shorter ? lens1 : lens0;

        std::vector<size_t> output_lens(longer);
        const auto offset = longer.size() - shorter.size();
        for(size_t i = 0; i < shorter.size(); ++i)
            output_lens[offset + i] = std::max(shorter[i], longer[offset + i]);

        auto lhs = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
        auto rhs = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
        return prog.add_instruction(x, lhs, rhs);
    }

    template <class T>
    void add_generic_op(std::string name, T x)
    {
Paul's avatar
Paul committed
225
        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
Khalique's avatar
Khalique committed
226
227
228
229
230
231
232
            return prog.add_instruction(x, args);
        });
    }

    // Emits op::batch_norm_inference (spatial mode, momentum 0.9) on `args`.
    // epsilon comes from the "epsilon" attribute when present, else 1e-5.
    instruction_ref
    parse_batchnorm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        const float epsilon =
            contains(attributes, "epsilon") ? attributes.at("epsilon").f() : 1e-5f;
        const float momentum = 0.9f;
        const op::batch_norm_inference::bn_infer_mode_t mode =
            op::batch_norm_inference::spatial;

        op::batch_norm_inference batch_norm{epsilon, momentum, mode};
        return prog.add_instruction(batch_norm, std::move(args));
    }

244
    // Adds args[1] to args[0] after broadcasting args[1] along axis 1 to the
    // lens of args[0].
    instruction_ref
    parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto input = args[0];
        auto bias  = args[1];

        // assume output of previous layer is in NCHW (broadcast on channel)
        const uint64_t channel_axis = 1;
        auto broadcast_bias =
            prog.add_instruction(op::broadcast{channel_axis, input->get_shape().lens()}, bias);
        return prog.add_instruction(op::add{}, input, broadcast_bias);
    }

Khalique's avatar
Khalique committed
252
253
254
255
    // Emits op::concat over all inputs except the last one.
    // The "N" attribute gives the index within `args` of the axis operand; that
    // operand is evaluated and remapped with parse_axis using the rank of args[0].
    instruction_ref
    parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        // get index for axis within args
        const size_t axis_idx = attributes.at("N").i();
        const size_t rank     = args[0]->get_shape().lens().size();
        const int64_t tf_axis = args[axis_idx]->eval().at<int64_t>();
        const size_t axis     = parse_axis(tf_axis, rank);

        // keep only the leading arguments (assuming last index is the axis value)
        std::vector<instruction_ref> inputs(args.begin(), args.end() - 1);
        return prog.add_instruction(op::concat{axis}, inputs);
    }

    // Adds the tensor stored in the "value" attribute as a literal.
    // Literals of rank >= 4 are followed by a transpose whose permutation is
    // the identity sequence rearranged by reorder_data.
    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&)
    {
        const literal value = parse_tensor(attributes.at("value").tensor());
        auto constant       = prog.add_literal(value);

        const size_t rank = constant->get_shape().lens().size();
        if(rank < 4)
            return constant;

        std::vector<int64_t> permutation = get_axes(rank);
        reorder_data(permutation);
        return prog.add_instruction(op::transpose{permutation}, constant);
    }

    // Builds an op::convolution from a node's attributes.
    // args[0] is the input and args[1] the weights. Recognised attributes:
    // "strides", "dilations" (4 values each, entries 2 and 3 are used after
    // reorder_data) and "padding" (SAME / VALID / EXPLICIT).
    instruction_ref
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution conv;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> strides;
            copy(attributes.at("strides").list().i(), std::back_inserter(strides));
            reorder_data(strides);
            if(strides.size() != 4)
                MIGRAPHX_THROW("strides should have 4 values");
            conv.stride[0] = strides[2];
            conv.stride[1] = strides[3];
        }

        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilations;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilations));
            reorder_data(dilations);
            if(dilations.size() != 4)
                MIGRAPHX_THROW("dilation should have 4 values");
            conv.dilation[0] = dilations[2];
            conv.dilation[1] = dilations[3];
        }

        // Weights that do not come from a "@param" instruction (presumably
        // constants) are transposed; the permutation depends on is_nhwc.
        auto weights = args[1];
        if(weights->name() != "@param")
        {
            const std::vector<int64_t> perm =
                is_nhwc ? std::vector<int64_t>{1, 3, 0, 2} : std::vector<int64_t>{3, 2, 0, 1};
            weights = prog.add_instruction(op::transpose{perm}, args[1]);
        }

        auto input = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& mode = attributes.at("padding").s();
            const auto mentions     = [&](const char* token) {
                return mode.find(token) != std::string::npos;
            };

            if(mentions("SAME"))
            {
                conv.padding_mode = op::padding_mode_t::same;

                const std::vector<size_t> kernel_lens = weights->get_shape().lens();
                const auto input_lens                 = input->get_shape().lens();
                const size_t kernel_h                 = kernel_lens[2];
                const size_t kernel_w                 = kernel_lens[3];
                const size_t input_h                  = input_lens[2];
                const size_t input_w                  = input_lens[3];

                std::vector<int64_t> pads(input_lens.size());
                calculate_padding(0, pads, input_h, conv.stride[0], conv.dilation[0], kernel_h);
                calculate_padding(1, pads, input_w, conv.stride[1], conv.dilation[1], kernel_w);

                const bool symmetric = pads[0] == pads[2] and pads[1] == pads[3];
                if(symmetric)
                {
                    conv.padding[0] = pads[0];
                    conv.padding[1] = pads[1];
                }
                else
                {
                    // Uneven padding is applied by a separate pad instruction on the input.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    input = prog.add_instruction(migraphx::op::pad{padding}, input);
                }
            }
            else if(mentions("VALID"))
            {
                conv.padding_mode = op::padding_mode_t::valid;
            }
            else if(mentions("EXPLICIT"))
            {
                // NOTE(review): exactly 4 values are required here; confirm this
                // matches the layout TF uses for "explicit_paddings".
                std::vector<size_t> explicit_pads;
                copy(attributes.at("explicit_paddings").list().i(),
                     std::back_inserter(explicit_pads));
                if(explicit_pads.size() != 4)
                    MIGRAPHX_THROW("padding should have 4 values");
                if(explicit_pads[0] != explicit_pads[2] or explicit_pads[1] != explicit_pads[3])
                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
                conv.padding[0] = explicit_pads[0];
                conv.padding[1] = explicit_pads[1];
            }
        }

        return prog.add_instruction(conv, {input, weights});
    }

Khalique's avatar
Khalique committed
377
378
379
    // Builds a grouped op::convolution for a depthwise convolution node.
    // The group count is dimension 1 of args[0]. The weights (args[1]) are made
    // contiguous and reshaped so that dim 0 becomes group * weights_dim0 and
    // dim 1 becomes 1. "strides", "dilations" and "padding" (SAME / VALID) are
    // handled as attributes.
    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args)
    {
        op::convolution conv;
        const size_t num_channels = args[0]->get_shape().lens()[1];
        conv.group                = num_channels;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> strides;
            copy(attributes.at("strides").list().i(), std::back_inserter(strides));
            reorder_data(strides);
            if(strides.size() != 4)
                MIGRAPHX_THROW("strides should have 4 values");
            conv.stride[0] = strides[2];
            conv.stride[1] = strides[3];
        }

        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilations;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilations));
            reorder_data(dilations);
            if(dilations.size() != 4)
                MIGRAPHX_THROW("dilation should have 4 values");
            conv.dilation[0] = dilations[2];
            conv.dilation[1] = dilations[3];
        }

        // Weights that do not come from a "@param" instruction (presumably
        // constants) are transposed; the permutation depends on is_nhwc.
        auto weights = args[1];
        if(weights->name() != "@param")
        {
            const std::vector<int64_t> perm =
                is_nhwc ? std::vector<int64_t>{1, 3, 0, 2} : std::vector<int64_t>{3, 2, 0, 1};
            weights = prog.add_instruction(op::transpose{perm}, args[1]);
        }

        auto input = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& mode = attributes.at("padding").s();
            const auto mentions     = [&](const char* token) {
                return mode.find(token) != std::string::npos;
            };

            if(mentions("SAME"))
            {
                conv.padding_mode = op::padding_mode_t::same;

                const std::vector<size_t> kernel_lens = weights->get_shape().lens();
                const auto input_lens                 = input->get_shape().lens();
                const size_t kernel_h                 = kernel_lens[2];
                const size_t kernel_w                 = kernel_lens[3];
                const size_t input_h                  = input_lens[2];
                const size_t input_w                  = input_lens[3];

                std::vector<int64_t> pads(input_lens.size());
                calculate_padding(0, pads, input_h, conv.stride[0], conv.dilation[0], kernel_h);
                calculate_padding(1, pads, input_w, conv.stride[1], conv.dilation[1], kernel_w);

                const bool symmetric = pads[0] == pads[2] and pads[1] == pads[3];
                if(symmetric)
                {
                    conv.padding[0] = pads[0];
                    conv.padding[1] = pads[1];
                }
                else
                {
                    // Uneven padding is applied by a separate pad instruction on the input.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    input = prog.add_instruction(migraphx::op::pad{padding}, input);
                }
            }
            else if(mentions("VALID"))
            {
                conv.padding_mode = op::padding_mode_t::valid;
            }
        }

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        std::vector<int64_t> reshaped_lens;
        copy(weights->get_shape().lens(), std::back_inserter(reshaped_lens));
        const int64_t multiplier   = reshaped_lens[0];
        const int64_t out_channels = num_channels * multiplier;
        reshaped_lens[0]           = out_channels;
        reshaped_lens[1]           = 1;

        // Make sure weights are contiguous before doing reshape
        auto contiguous_weights = prog.add_instruction(op::contiguous{}, weights);
        auto grouped_weights =
            prog.add_instruction(op::reshape{reshaped_lens}, contiguous_weights);

        return prog.add_instruction(conv, {input, grouped_weights});
    }

Khalique's avatar
Khalique committed
477
478
    // Inserts a dimension of size 1 into the shape of args[0] via op::reshape.
    // The position is the evaluated value of args[1] after parse_axis; a
    // negative value d inserts at rank + d + 1 (so -1 appends at the end).
    instruction_ref
    parse_expanddims(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        const std::vector<size_t> input_dims = args[0]->get_shape().lens();
        const size_t num_dims                = input_dims.size();
        const int32_t dim = parse_axis(args[1]->eval().at<int32_t>(), num_dims);

        // Unsigned wrap-around makes num_dims + dim + 1 come out right for negative dim.
        const size_t insert_at = (dim < 0) ? (num_dims + dim + 1) : static_cast<size_t>(dim);

        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
        new_dims.insert(new_dims.begin() + insert_at, 1);
        return prog.add_instruction(op::reshape{new_dims}, args[0]);
    }

Khalique's avatar
Khalique committed
496
497
    // Emits op::dot on args[0] and args[1].
    // The "transpose_a" / "transpose_b" attributes, when present and true,
    // swap the last two dimensions of the first / second operand beforehand.
    // Throws if args[0] has fewer than 2 dimensions.
    instruction_ref
    parse_matmul(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        bool transa = false;
        bool transb = false;

        if(contains(attributes, "transpose_a"))
        {
            transa = attributes.at("transpose_a").b();
        }
        if(contains(attributes, "transpose_b"))
        {
            // Read the flag for the second operand from its own attribute.
            transb = attributes.at("transpose_b").b();
        }

        const size_t rank = args[0]->get_shape().lens().size();
        if(rank < 2)
        {
            // Swapping the last two axes below needs at least two of them.
            MIGRAPHX_THROW("TF_PARSER: MatMul inputs must have at least 2 dimensions");
        }

        std::vector<int64_t> perm(rank);
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::iter_swap(perm.end() - 1, perm.end() - 2);

        auto l1 = (transa) ? prog.add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];

        return prog.add_instruction(op::dot{}, l1, l2);
    }

Khalique's avatar
Khalique committed
522
523
    // Lowers a mean to average pooling. Only the case where args[0] has rank 4
    // and the remapped reduction axes equal {2, 3} is handled: the pooling
    // window then spans all of dims 2 and 3. Any other case throws.
    // With "keep_dims" false, dims 2 and 3 are squeezed from the result.
    instruction_ref
    parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        const bool keep_dims = attributes.at("keep_dims").b();
        const std::vector<int32_t> spatial_axes{2, 3};

        const auto input_lens = args[0]->get_shape().lens();
        const auto axes =
            parse_axes(args[1]->eval().get<int32_t>().to_vector(), input_lens.size());

        // check if conditions for GlobalAvgPool are met
        if(input_lens.size() != 4 or axes != spatial_axes)
            MIGRAPHX_THROW(
                "MIGraphX does not support mean outside of GlobalAvgPool transformation");

        op::pooling global_avg{"average"};
        global_avg.lengths[0] = input_lens[2];
        global_avg.lengths[1] = input_lens[3];
        auto pooled           = prog.add_instruction(global_avg, args.front());
        if(keep_dims)
            return pooled;

        const std::vector<int64_t> squeeze_axes(spatial_axes.begin(), spatial_axes.end());
        return prog.add_instruction(op::squeeze{squeeze_axes}, pooled);
    }

Khalique's avatar
Khalique committed
545
546
547
    // Implements Pack as "unsqueeze every input at `axis`, then concat on `axis`".
    // `axis` comes from the "axis" attribute (default 0), must not exceed the
    // rank of the first input, and is remapped with parse_axis.
    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args)
    {
        const int64_t requested_axis =
            contains(attributes, "axis") ? attributes.at("axis").i() : 0;

        const size_t input_size = args.front()->get_shape().lens().size();
        if(requested_axis > input_size)
        {
            MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(requested_axis) +
                           " must be smaller than input size " + to_string(input_size));
        }

        // check if input arg needs axis to be converted to NCHW
        const int64_t axis = parse_axis(requested_axis, input_size);

        // reinterpret as unsqueeze with concat
        std::vector<instruction_ref> unsqueezed_args;
        unsqueezed_args.reserve(args.size());
        for(const auto& arg : args)
            unsqueezed_args.push_back(prog.add_instruction(op::unsqueeze{{axis}}, arg));

        return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
    }

Khalique's avatar
Khalique committed
571
572
573
574
575
    // Emits op::pad on args.front() using the paddings held by args[1].
    // args[1] is evaluated and must provide exactly two int32 values (before,
    // after) per dimension of the input; otherwise an exception is thrown.
    // The per-dimension pairs are rearranged with reorder_data and then laid
    // out as all "before" values followed by all "after" values.
    instruction_ref
    parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        const size_t ndims = args.front()->get_shape().lens().size();

        // in tf, the paddings are arranged as a 2d shape (ndims, 2),
        // the last dim contains the left padding and right padding respectively
        const auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        if(tf_padding.size() != 2 * ndims)
        {
            // Reject a mismatched paddings tensor instead of indexing past its end.
            MIGRAPHX_THROW("TF_PARSER: Pad expects " + to_string(2 * ndims) +
                           " padding values but got " + to_string(tf_padding.size()));
        }

        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        for(size_t i = 0; i < ndims; i++)
        {
            pad_per_dim[i].first  = tf_padding[2 * i];
            pad_per_dim[i].second = tf_padding[2 * i + 1];
        }
        reorder_data(pad_per_dim);

        op::pad pad_op;
        std::vector<int64_t> pads(ndims * 2);
        for(size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;
            pads[i + ndims] = pad_per_dim[i].second;
        }
        pad_op.pads = pads;
        return prog.add_instruction(pad_op, args.front());
    }

598
599
600
601
602
    // Builds an op::pooling for a pooling node: mode "max" when `name` starts
    // with "Max", otherwise "average". Recognised attributes: "strides" and
    // "ksize" (4 values each, entries 2 and 3 are used after reorder_data) and
    // "padding" (SAME / VALID).
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args)
    {
        const bool is_max_pool = starts_with(name, "Max");
        op::pooling pool{is_max_pool ? "max" : "average"};

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            pool.stride[0] = stride[2];
            pool.stride[1] = stride[3];
        }
        if(contains(attributes, "ksize"))
        {
            std::vector<size_t> ksize;
            copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
            reorder_data(ksize);
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            pool.lengths[0] = ksize[2];
            pool.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                pool.padding_mode = op::padding_mode_t::same;
                auto input_dims   = l0->get_shape().lens();
                size_t input_h    = input_dims[2];
                size_t input_w    = input_dims[3];
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_h, pool.stride[0], 1, pool.lengths[0]);
                calculate_padding(1, pads, input_w, pool.stride[1], 1, pool.lengths[1]);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    // Uneven padding needs an explicit pad instruction. The fill
                    // value must not disturb the reduction: the lowest float can
                    // never win a max, but it would wreck an average, so average
                    // pooling pads with zero instead.
                    // NOTE(review): zero padding still counts padded cells in the
                    // average; confirm whether that matches the intended semantics.
                    const float pad_value =
                        is_max_pool ? std::numeric_limits<float>::lowest() : 0.0f;
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = prog.add_instruction(migraphx::op::pad{padding, pad_value}, l0);
                }
                else
                {
                    pool.padding[0] = pads[0];
                    pool.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                pool.padding_mode = op::padding_mode_t::valid;
            }
        }
        return prog.add_instruction(pool, l0);
    }
Khalique's avatar
Khalique committed
662

663
    // Builds a reshape instruction from a TF Reshape node.
    // args[0] is the tensor to reshape; args[1] is evaluated and its values become the
    // target dimensions. Throws unless exactly two inputs are supplied.
    instruction_ref
    parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");

        op::reshape reshape_op;
        auto new_shape = args[1]->eval();
        // copy the evaluated values, whatever their element type, into the op's dims
        new_shape.visit(
            [&](auto values) { copy(values, std::back_inserter(reshape_op.dims)); });
        return prog.add_instruction(reshape_op, args[0]);
    }

Khalique's avatar
Khalique committed
674
675
676
677
678
679
680
681
682
    void parse_from(std::istream& is)
    {
        tensorflow::GraphDef graph;
        if(graph.ParseFromIstream(&is))
        {
            this->parse_graph(graph);
        }
        else
        {
683
            throw std::runtime_error("Failed reading tf file");
Khalique's avatar
Khalique committed
684
685
686
        }
    }

687
688
689
690
691
692
693
694
695
696
    // Lowers a TF Softmax node: the input is reshaped to {d0, d1, 1, 1}, softmax is applied,
    // and the result is reshaped back to {d0, d1}. Only the first two input dimensions are
    // read.
    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        auto input       = args.front();
        auto lens        = input->get_shape().lens();
        const long rows  = long(lens[0]);
        const long cols  = long(lens[1]);
        auto expanded    = prog.add_instruction(op::reshape{{rows, cols, 1, 1}}, input);
        auto normalized  = prog.add_instruction(op::softmax{}, expanded);
        return prog.add_instruction(op::reshape{{rows, cols}}, normalized);
    }

Khalique's avatar
Khalique committed
697
698
699
    // Lowers a TF Squeeze node.
    //
    // attributes - "squeeze_dims" lists the axes to remove; when it is missing or empty,
    //              every dimension of size 1 is removed instead.
    // args       - args[0] is the tensor to squeeze.
    //
    // Returns the squeeze instruction.
    instruction_ref parse_squeeze(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        op::squeeze op;
        auto input_dims = args[0]->get_shape().lens();
        // parse_axes looks the attribute up with at() and would throw if it is absent, so
        // only call it when the attribute exists; a missing attribute is treated like an
        // empty list
        if(contains(attributes, "squeeze_dims"))
        {
            auto axes = parse_axes(attributes, "squeeze_dims", input_dims.size());
            copy(axes, std::back_inserter(op.axes));
        }

        if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
        {
            for(size_t i = 0; i < input_dims.size(); i++)
            {
                if(input_dims.at(i) == 1)
                {
                    op.axes.push_back(i);
                }
            }
        }
        return prog.add_instruction(op, args[0]);
    }

Khalique's avatar
Khalique committed
719
720
721
    // Lowers a TF StridedSlice node to a slice followed by a squeeze.
    //
    // args[0] is the input; args[1] and args[2] are evaluated as int32 begin/end indices.
    // Every axis of the input is sliced, and the axes whose bit is set in the
    // "shrink_axis_mask" attribute are squeezed afterwards. No stride input is read here.
    instruction_ref parse_stridedslice(const std::string&,
                                       const attribute_map& attributes,
                                       std::vector<instruction_ref> args)
    {
        auto begin_idx    = args[1]->eval().get<int32_t>().to_vector();
        auto end_idx      = args[2]->eval().get<int32_t>().to_vector();
        const size_t rank = args[0]->get_shape().lens().size();
        if(rank >= 4)
        {
            reorder_data(begin_idx);
            reorder_data(end_idx);
        }

        op::slice slice_op;
        slice_op.starts.assign(begin_idx.begin(), begin_idx.end());
        slice_op.ends.assign(end_idx.begin(), end_idx.end());
        // slice along every axis: 0, 1, ..., rank - 1
        slice_op.axes.resize(rank);
        std::iota(slice_op.axes.begin(), slice_op.axes.end(), 0);

        uint32_t mask = 0;
        if(contains(attributes, "shrink_axis_mask"))
            mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

        std::vector<int64_t> squeeze_axes;
        for(size_t axis = 0; axis < rank; axis++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            const bool shrink = ((mask >> axis) & 1u) != 0u;
            if(shrink)
                squeeze_axes.push_back(axis);
        }
        squeeze_axes = parse_axes(squeeze_axes, rank);

        auto sliced = prog.add_instruction(slice_op, args[0]);
        return prog.add_instruction(op::squeeze{squeeze_axes}, sliced);
    }

Khalique's avatar
Khalique committed
756
757
758
759
760
    // Populates the program from a GraphDef.
    // Every node collected into input_nodes by get_nodes becomes a program parameter, with
    // its type and dimensions taken from the "dtype" and "shape" attributes. Afterwards
    // every node is parsed by name.
    void parse_graph(const tensorflow::GraphDef& graph)
    {
        nodes = get_nodes(graph, input_nodes);

        for(const auto& placeholder : input_nodes)
        {
            const attribute_map attrs      = get_attributes(placeholder);
            const shape::type_t param_type = parse_type(attrs.at("dtype").type());
            std::vector<size_t> param_dims = parse_dims(attrs.at("shape").shape());
            // dimensions of rank >= 4 are reordered when the graph is flagged as NHWC
            if(is_nhwc and param_dims.size() >= 4)
                reorder_data(param_dims);

            const std::string& param_name = placeholder.name();
            instructions[param_name] =
                prog.add_parameter(param_name, shape{param_type, param_dims});
        }

        for(const auto& entry : nodes)
            this->parse_node(entry.first);
    }

    void parse_node(const std::string& name)
    {
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
                if(nodes.count(input) > 0)
                {
                    auto&& iname = get_name(nodes.at(input));
                    assert(name != iname);
                    this->parse_node(iname);
                    args.push_back(instructions.at(iname));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }
            if(ops.count(node.op()) == 0)
            {
801
                instructions[name] = prog.add_instruction(op::unknown{node.op()}, args);
Khalique's avatar
Khalique committed
802
803
804
805
806
807
808
809
810
811
812
            }
            else
            {
                instructions[name] = ops[node.op()](get_attributes(node), args);
            }
        }
    }

    // Copies the attributes of `node` into an attribute_map keyed by attribute name.
    static attribute_map get_attributes(const tensorflow::NodeDef& node)
    {
        attribute_map collected;
        const auto& node_attrs = node.attr();
        for(auto it = node_attrs.begin(); it != node_attrs.end(); ++it)
        {
            collected[it->first] = it->second;
        }
        return collected;
    }

Khalique's avatar
Khalique committed
820
    // Returns a copy of the name stored in `node`.
    static std::string get_name(const tensorflow::NodeDef& node)
    {
        std::string node_name = node.name();
        return node_name;
    }
Khalique's avatar
Khalique committed
821

Khalique's avatar
Khalique committed
822
823
    // Indexes every node of `graph` by name and returns the resulting map.
    // Nodes whose op is "Placeholder" are additionally appended to `input_nodes`.
    // Throws when a node has an empty name.
    static node_map get_nodes(const tensorflow::GraphDef& graph,
                              std::vector<tensorflow::NodeDef>& input_nodes)
    {
        node_map by_name;
        for(const auto& graph_node : graph.node())
        {
            const std::string key = get_name(graph_node);
            // assume each node in graph has an associated name
            if(key.empty())
                MIGRAPHX_THROW("tf node with no name found");
            by_name[key] = graph_node;

            const bool is_placeholder = graph_node.op() == "Placeholder";
            if(is_placeholder)
                input_nodes.push_back(graph_node);
        }
        return by_name;
    }

    // Maps a TF data type onto the matching migraphx shape type.
    // Types without a mapping do not raise an error; they yield a value-initialized
    // shape::type_t.
    static shape::type_t parse_type(const tensorflow::DataType t)
    {
        switch(t)
        {
        // types with a direct counterpart
        case tensorflow::DataType::DT_FLOAT: return shape::float_type;
        case tensorflow::DataType::DT_DOUBLE: return shape::double_type;
        case tensorflow::DataType::DT_HALF: return shape::half_type;
        case tensorflow::DataType::DT_INT8: return shape::int8_type;
        case tensorflow::DataType::DT_INT16: return shape::int16_type;
        case tensorflow::DataType::DT_INT32: return shape::int32_type;
        case tensorflow::DataType::DT_INT64: return shape::int64_type;
        case tensorflow::DataType::DT_UINT16: return shape::uint16_type;
        case tensorflow::DataType::DT_UINT32: return shape::uint32_type;
        case tensorflow::DataType::DT_UINT64: return shape::uint64_type;

        // unsupported value types
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // tf pb should not use these types
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        // nothing matched: hand back the value-initialized type instead of throwing
        return shape::type_t{};
    }

Khalique's avatar
Khalique committed
917
    // Converts a TF TensorProto into a migraphx literal.
    //
    // When tensor_content is non-empty its bytes are used directly as the literal data.
    // Otherwise the values are read from the typed repeated field that matches dtype and
    // expanded to the full element count by get_data_vals.
    //
    // Throws std::runtime_error for data types that have no literal mapping here, and via
    // MIGRAPHX_THROW when dtype is not a known enumerator.
    static literal parse_tensor(const tensorflow::TensorProto& t)
    {
        std::vector<size_t> dims = parse_dims(t.tensor_shape());
        // seed with size_t: an int seed would make accumulate keep (and truncate) the
        // running product as int
        size_t shape_size =
            std::accumulate(dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());
        const std::string& raw = t.tensor_content();
        const bool has_raw     = !raw.empty();
        const auto dtype       = t.dtype();

        switch(dtype)
        {
        case tensorflow::DataType::DT_FLOAT:
            if(has_raw)
                return literal{{shape::float_type, dims}, raw.data()};
            return create_literal(
                shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
        case tensorflow::DataType::DT_INT8:
            if(has_raw)
                return literal{{shape::int8_type, dims}, raw.data()};
            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_UINT16:
            if(has_raw)
                return literal{{shape::uint16_type, dims}, raw.data()};
            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT16:
            if(has_raw)
                return literal{{shape::int16_type, dims}, raw.data()};
            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT32:
            if(has_raw)
                return literal{{shape::int32_type, dims}, raw.data()};
            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT64:
            if(has_raw)
                return literal{{shape::int64_type, dims}, raw.data()};
            return create_literal(
                shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
        case tensorflow::DataType::DT_BOOL:
            // NOTE(review): raw bools become int8 while bool_val becomes int32; kept as found,
            // confirm whether the two paths are meant to differ
            if(has_raw)
                return literal{{shape::int8_type, dims}, raw.data()};
            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
        case tensorflow::DataType::DT_HALF:
        {
            if(has_raw)
                return literal{{shape::half_type, dims}, raw.data()};
            // half_val carries each value in an int; narrow to 16 bits and reinterpret the
            // bit pattern as half
            std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
            std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case tensorflow::DataType::DT_DOUBLE:
            if(has_raw)
                return literal{{shape::double_type, dims}, raw.data()};
            // routed through create_literal like every other type so that doubles get the
            // same scalar handling
            return create_literal(
                shape::double_type, dims, get_data_vals(t.double_val(), shape_size));

        // no literal mapping for any of these
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_UINT32:
        case tensorflow::DataType::DT_UINT64:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
            // same exception type as before, but with a message that names the culprit
            throw std::runtime_error("Unsupported tf tensor data type: " +
                                     std::to_string(static_cast<int>(dtype)));
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

1063
    template <class T>
Khalique's avatar
Khalique committed
1064
    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
1065
                                        const size_t& shape_size)
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
    {
        std::vector<T> data_vals(shape_size);
        // check if shape has enough data values given existing fields
        if(data.size() == 1)
        {
            std::fill(data_vals.begin(), data_vals.end(), data[0]);
        }
        else
            copy(data.begin(), data.end(), std::back_inserter(data_vals));
        return data_vals;
    }

Khalique's avatar
Khalique committed
1078
1079
1080
1081
    // Collects the size of every dimension in `s`, in order.
    // NOTE(review): each size is converted to size_t as-is; presumably a negative (unknown)
    // dimension would wrap to a very large value - confirm how callers handle that.
    static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
    {
        std::vector<size_t> dims;
        const auto& proto_dims = s.dim();
        for(const tensorflow::TensorShapeProto_Dim& entry : proto_dims)
        {
            dims.push_back(entry.size());
        }
        return dims;
    }
1088
1089

    template <class T>
Khalique's avatar
Khalique committed
1090
    static literal
1091
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
1092
    {
Khalique's avatar
Khalique committed
1093
        // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
1094
        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
1095
            return literal{{shape_type}, data};
1096
1097
        return literal{{shape_type, dims}, data};
    }
Khalique's avatar
Khalique committed
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
};

// Parses the TF graph stored in the file `name` and returns the resulting program.
// `is_nhwc` is forwarded to the parser's is_nhwc flag. In builds without NDEBUG the
// partially built program is written to std::cerr before a parse failure is rethrown.
program parse_tf(const std::string& name, bool is_nhwc)
{
    tf_parser parser;
    parser.is_nhwc = is_nhwc;
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);

#ifdef NDEBUG
    parser.parse_from(input);
#else
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#endif
    return std::move(parser.prog);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx