#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/pad_calc.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

    node_map nodes;
    std::vector<tensorflow::NodeDef> input_nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog            = program();
    module* mm              = prog.get_main_module();
    bool is_nhwc            = true;
    unsigned int batch_size = 1;

    std::unordered_map<std::string, op_func> ops;

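    // TF graphs are typically NHWC while migraphx prefers NCHW; these helpers
    // insert transposes between the two layouts for 4-D tensors.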
    bool should_transpose(instruction_ref ins) const
    {
        return is_nhwc and ins->get_shape().lens().size() == 4;
    }

    instruction_ref to_nhwc(instruction_ref ins) const
    {
        if(should_transpose(ins))
            return mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
        return ins;
    }

    instruction_ref to_nchw(instruction_ref ins) const
    {
        if(should_transpose(ins))
            return mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
        return ins;
    }

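    // TF stores 2-D convolution weights in HWIO order (height, width, in, out);
    // transpose them to the OIHW (k, c, y, x) order expected by migraphx convolution.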
    instruction_ref to_kcxy(instruction_ref ins) const
    {
        return mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
    }

    instruction_ref make_contiguous(instruction_ref ins) const
    {
        if(ins->get_shape().standard())
            return ins;
        else
            return mm->add_instruction(op::contiguous{}, ins);
    }

    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
        return result;
    }

    std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
        return result;
    }

    std::vector<size_t>
    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
    {
        auto attrs = attributes.at(s).list().i();
        std::vector<size_t> axes;
        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
        if(is_nhwc)
        {
            std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
                return parse_axis(axis, num_dims);
            });
        }
        return axes;
    }

    template <class T>
    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
    {
        if(is_nhwc)
        {
            std::vector<T> new_axes;
            std::transform(axes.begin(),
                           axes.end(),
                           std::back_inserter(new_axes),
                           [&](size_t axis) { return parse_axis(axis, num_dims); });
            return new_axes;
        }
        return axes;
    }

    // tf stores certain attributes, such as strides and dilations, as 4-D values in
    // NHWC order, where the first and last entries are 1 and the spatial values sit in
    // between. This helper reorders the values into NCHW order so they can be assigned
    // to the corresponding operator members.
    template <class T>
    void reorder_data(std::vector<T>& prev_data) const
    {
        std::vector<T> new_data(prev_data.size());
        for(size_t i = 0; i < new_data.size(); i++)
        {
            auto new_idx         = parse_axis(i, new_data.size());
            new_data.at(new_idx) = prev_data.at(i);
        }
        prev_data = new_data;
    }

    template <class T>
    T parse_axis(const T& dim, const size_t num_dims) const
    {
        T new_dim = dim;
        if(is_nhwc and num_dims >= 4)
        {
            switch(dim)
            {
            case 0: new_dim = 0; break;
            case 1: new_dim = 2; break;
            case 2: new_dim = 3; break;
            case 3: new_dim = 1; break;
            default: break;
            }
        }
        return new_dim;
    }

    std::vector<int64_t> get_axes(size_t num_axes) const
    {
        std::vector<int64_t> axes(num_axes);
        std::iota(axes.begin(), axes.end(), 0);
        return axes;
    }

    std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
    {
        uint32_t bitwise_compare = 1;
        std::vector<int64_t> axes;
        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB of the mask corresponds to axis 0; record 1 for masked axes, 0 otherwise
            if(((mask >> i) & bitwise_compare) == 1)
                axes.push_back(1);
            else
                axes.push_back(0);
        }
        return axes;
    }

    tf_parser()
    {
        add_generic_op("All", op::identity{});
        add_generic_op("Identity", op::identity{});
        add_generic_op("LessEqual", op::identity{});
        add_generic_op("Relu", op::relu{});
        add_generic_op("Rsqrt", op::rsqrt{});
        add_generic_op("Tanh", op::tanh{});
        add_generic_op("StopGradient", op::identity{});

        add_binary_op("Add", op::add{});
        add_binary_op("AddV2", op::add{});
        add_binary_op("Mul", op::mul{});
        add_binary_op("Pow", op::pow{});
        add_binary_op("SquaredDifference", op::sqdiff{});
        add_binary_op("Sub", op::sub{});

        add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
        add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
        add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
        add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
        add_mem_op("Cast", &tf_parser::parse_cast, false);
        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
        add_mem_op("Const", &tf_parser::parse_constant);
        add_mem_op("Conv2D", &tf_parser::parse_conv);
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
        add_mem_op("FusedBatchNormV3", &tf_parser::parse_batchnorm);
        add_mem_op("GatherV2", &tf_parser::parse_gather, false);
        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
        add_mem_op("Mean", &tf_parser::parse_mean, false);
        add_mem_op("OneHot", &tf_parser::parse_onehot, false);
        add_mem_op("Pack", &tf_parser::parse_pack, false);
        add_mem_op("Pad", &tf_parser::parse_pad);
        add_mem_op("Relu6", &tf_parser::parse_relu6);
        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
        add_mem_op("Shape", &tf_parser::parse_shape, false);
        add_mem_op("Slice", &tf_parser::parse_slice, false);
        add_mem_op("Split", &tf_parser::parse_split, false);
        add_mem_op("SplitV", &tf_parser::parse_split, false);
        add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
        add_mem_op("Transpose", &tf_parser::parse_transpose, false);
    }

    template <class F>
    void add_op(const std::string& name, F f, bool transpose = true)
    {
        if(transpose)
        {
            ops.emplace(
                name,
                op_func{
                    [=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
                        return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
                    }});
        }
        else
        {
            ops.emplace(name,
                        op_func{[=](const attribute_map& attributes,
                                    const std::vector<instruction_ref>& args) {
                            return std::vector<instruction_ref>{f(attributes, args)};
                        }});
        }
    }

    template <class F>
    void add_mem_op(std::string name, F f, bool transpose = true)
    {
        add_op(name,
               [=](auto&&... xs) {
                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
               },
               transpose);
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   if(args.size() != 2)
                       MIGRAPHX_THROW("binary operators should have 2 operands");
                   // TODO
                   // if(contains(attributes, "data_format"))
                   // {
                   //     if(is_nhwc)
                   //     {
                   //         l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
                   //     }
                   // }
                   return add_broadcastable_binary_op(args[0], args[1], x);
               },
               false);
    }

    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        if(arg0->get_shape().lens() != arg1->get_shape().lens())
        {
            // Example:
            // s0 = (3,2,4,5) and s1 = (2,1,1)
            //
            // In this case we need to broadcast (:,1,1) portion of
            // s1 plus broadcast the 1st dimension of s1
            // giving output_lens = (3,2,4,5)
            //
            // Another example:
            // s0 = (3,2,1,5) and s1 = (2,7,5)
            // In this case we need to broadcast the (:,:,1:,:) axis
            // of s0 plus the 1st dimension of s1 giving
            // output_lens = (3,2,7,5)
            //
            // Get lengths for both arguments
            const std::vector<size_t>* s0 = &arg0->get_shape().lens();
            const std::vector<size_t>* s1 = &arg1->get_shape().lens();

            // Make sure s0 is the smaller size
            if(s0->size() > s1->size())
                std::swap(s0, s1);

            std::vector<size_t> output_lens(*s1);
            auto offset = s1->size() - s0->size();
            std::transform(s0->begin(),
                           s0->end(),
                           s1->begin() + offset,
                           output_lens.begin() + offset,
                           [](auto a, auto b) { return std::max(a, b); });

            auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
            auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
            return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
        }
        else
        {
            return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
        }
    }

    template <class T>
    void add_generic_op(std::string name, T x, bool transpose = true)
    {
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   return mm->add_instruction(x, args);
               },
               transpose);
    }

    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        int64_t axis = 0;
        axis         = args[1]->eval().at<int64_t>();
        auto ins     = mm->add_instruction(Op{axis}, args.front());
        return mm->add_instruction(op::squeeze{{axis}}, ins);
    }

    instruction_ref parse_batchnorm(const std::string&,
                                    attribute_map attributes,
                                    std::vector<instruction_ref> args) const
    {
        float epsilon                                     = 1e-5f;
        float momentum                                    = 0.9f;
        op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
        if(contains(attributes, "epsilon"))
        {
            epsilon = attributes.at("epsilon").f();
        }
        op::batch_norm_inference op{epsilon, momentum, bn_mode};
        return mm->add_instruction(op, std::move(args));
    }

    instruction_ref
    parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        uint64_t axis = 1; // assume output of previous layer is in NCHW (broadcast on channel)
        auto l0 = mm->add_instruction(op::broadcast{axis, args[0]->get_shape().lens()}, args[1]);
        return mm->add_instruction(op::add{}, args[0], l0);
    }

    instruction_ref parse_cast(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        shape::type_t type = parse_type(attributes.at("DstT").type());
        return mm->add_instruction(op::convert{type}, std::move(args));
    }

    instruction_ref parse_concat(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        // get index for axis within args
        size_t axis_idx = attributes.at("N").i();
        int64_t axis    = args[axis_idx]->eval().at<int64_t>();
        op::concat op{axis};
        // return only first N arguments (assuming last index is the axis value)
        return mm->add_instruction(
            op, std::vector<instruction_ref>(args.begin(), args.begin() + args.size() - 1));
    }

    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&) const
    {
        literal v = parse_tensor(attributes.at("value").tensor());
        return mm->add_literal(v);
    }

    instruction_ref parse_conv(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        op::convolution op;
        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
            reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto weights = to_kcxy(args[1]);
        auto l0      = args[0];
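        // SAME padding: compute the padding needed from the input and filter sizes; if it
        // is asymmetric, insert an explicit pad op instead of setting op.padding.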
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
            else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
                std::vector<size_t> padding;
                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
                if(padding.size() != 4)
                {
                    MIGRAPHX_THROW("padding should have 4 values");
                }
                if(padding[0] != padding[2] || padding[1] != padding[3])
                {
                    MIGRAPHX_THROW("migraphx does not support asymmetric padding");
                }
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        return mm->add_instruction(op, {l0, weights});
    }

    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args) const
    {
        op::convolution op;
        size_t num_channels = args[0]->get_shape().lens()[1];
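        // depthwise convolution maps to a grouped convolution with one group per input channel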
        op.group            = num_channels;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }

        auto weights = to_kcxy(args[1]);
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
            reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();

            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
        }

        std::vector<int64_t> new_weights_shape;
        copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = new_weights_shape[0];
        int64_t out_channels = num_channels * multiplier;
        new_weights_shape[0] = out_channels;
        new_weights_shape[1] = 1;
        // Make sure weights are contiguous before doing reshape
        auto new_weights =
            mm->add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));

        return mm->add_instruction(op, {l0, new_weights});
    }

    instruction_ref parse_expanddims(const std::string&,
                                     const attribute_map&,
                                     std::vector<instruction_ref> args) const
    {
        std::vector<size_t> input_dims = args[0]->get_shape().lens();
        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
        size_t num_dims = input_dims.size();
        int32_t dim     = args[1]->eval().at<int32_t>();

        if(dim < 0)
        {
            new_dims.insert(new_dims.begin() + (num_dims + dim + 1), 1);
        }
        else
        {
            new_dims.insert(new_dims.begin() + dim, 1);
        }
        return mm->add_instruction(op::reshape{new_dims}, args[0]);
    }

    instruction_ref
    parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        int axis = args[2]->eval().at<int32_t>();
        op::gather op{axis};
        return mm->add_instruction(op, {args[0], args[1]});
    }

    instruction_ref parse_matmul(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        bool transa = false;
        bool transb = false;

        if(contains(attributes, "transpose_a"))
        {
            transa = attributes.at("transpose_a").b();
        }
        if(contains(attributes, "transpose_b"))
        {
            transb = attributes.at("transpose_b").b();
        }

        if(contains(attributes, "adj_x"))
        {
            transa = attributes.at("adj_x").b();
        }
        if(contains(attributes, "adj_y"))
        {
            transb = attributes.at("adj_y").b();
        }

        std::vector<int64_t> perm(args[0]->get_shape().lens().size());
        std::iota(perm.begin(), perm.end(), int64_t{0});
        // swap the last two elements
        std::iter_swap(perm.end() - 1, perm.end() - 2);

        auto l1 = (transa) ? mm->add_instruction(op::transpose{perm}, args[0]) : args[0];
        auto l2 = (transb) ? mm->add_instruction(op::transpose{perm}, args[1]) : args[1];

        return mm->add_instruction(op::dot{}, l1, l2);
    }

    instruction_ref parse_mean(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        bool keep_dims = attributes.at("keep_dims").b();
        auto axes      = args[1]->eval().get<int32_t>().to_vector<int64_t>();

        if(keep_dims)
        {
            return mm->add_instruction(op::reduce_mean{axes}, args[0]);
        }
        else
        {
            auto ins = mm->add_instruction(op::reduce_mean{axes}, args[0]);
            return mm->add_instruction(op::squeeze{axes}, ins);
        }
    }

    instruction_ref parse_onehot(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        size_t depth = static_cast<size_t>(args[1]->eval().at<int32_t>());

        int64_t axis    = -1;
        float on_value  = args[2]->eval().at<float>();
        float off_value = args[3]->eval().at<float>();

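        // build a depth x depth matrix filled with off_value and on_value on the diagonal,
        // then gather rows of it using the indices in args[0]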
        std::vector<float> depth_input(depth * depth, off_value);
        for(int i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        if(axis == -1)
        {
            shape s{shape::float_type, {depth, depth}};
            auto l0 = mm->add_literal({s, depth_input});
            return mm->add_instruction(op::gather{0}, {l0, args[0]});
        }
        MIGRAPHX_THROW("MIGraphX does not support axis != -1");
    }

    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args) const
    {
        // reinterpret as unsqueeze with concat
        std::vector<instruction_ref> unsqueezed_args;
        int64_t axis = 0;
        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        size_t input_size = args.front()->get_shape().lens().size();
        if(axis > input_size)
        {
            MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                           " must be less than or equal to input size " + to_string(input_size));
        }

        std::transform(
            args.begin(),
            args.end(),
            std::back_inserter(unsqueezed_args),
            [&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
        return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
    }

    instruction_ref
    parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        size_t ndims = args.front()->get_shape().lens().size();

        // in tf, the paddings are arranged as a 2d shape (ndims, 2),
        // the last dim contains the left padding and right padding respectively
        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        for(size_t i = 0; i < 2 * ndims; i += 2)
        {
            pad_per_dim[i / 2].first  = tf_padding[i];
            pad_per_dim[i / 2].second = tf_padding[i + 1];
        }
        reorder_data(pad_per_dim);

        op::pad op;
        std::vector<int64_t> pads(ndims * 2);
        for(size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;
            pads[i + ndims] = pad_per_dim[i].second;
        }
        op.pads = pads;
        return mm->add_instruction(op, args.front());
    }

    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::pooling op{starts_with(name, "Max") ? "max" : "average"};

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "ksize"))
        {
            std::vector<size_t> ksize;
            copy(attributes.at("ksize").list().i(), std::back_inserter(ksize));
            reorder_data(ksize);
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            op.lengths[0] = ksize[2];
            op.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
                calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0                           = mm->add_instruction(
                        migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
        }
        return mm->add_instruction(op, l0);
    }

    instruction_ref
    parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        auto input_lens = args[0]->get_shape().lens();
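        // Relu6 is clip(x, 0, 6); broadcast the scalar bounds to the input shape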
        auto min_val    = mm->add_literal(0.0f);
        auto max_val    = mm->add_literal(6.0f);

        min_val = mm->add_instruction(op::multibroadcast{input_lens}, min_val);
        max_val = mm->add_instruction(op::multibroadcast{input_lens}, max_val);
        return mm->add_instruction(op::clip{}, args.front(), min_val, max_val);
    }

    instruction_ref
    parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        op::reshape op;
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
        auto s = args[1]->eval();
        s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    // Use a literal instruction to hold the shape, since the output of the
    // shape operator is a literal in migraphx
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
        std::vector<int32_t> vec_shape(arg_shape.size());
        migraphx::shape s(migraphx::shape::int32_type, {arg_shape.size()});
        std::transform(
            arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) { return i; });
        return mm->add_literal(migraphx::literal{s, vec_shape});
    }

    instruction_ref
    parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
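        // TF Slice takes begin indices and sizes; convert each size into an end index
        // (a size of -1 means slice to the end of that axis)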
        op::slice op;
        auto starts     = args[1]->eval().get<int32_t>().to_vector();
        auto size       = args[2]->eval().get<int32_t>().to_vector();
        auto axes       = args[0]->get_shape().lens();
        size_t num_axes = axes.size();

        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(num_axes);
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        for(size_t i = 0; i < num_axes; i++)
        {
            if(size[i] == -1)
                op.ends[i] = axes[i];
            else
                op.ends[i] = starts[i] + size[i];
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    // template to facilitate the logsoftmax later
    template <class Op>
    instruction_ref parse_softmax(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        int axis      = -1;
        auto num_dims = args[0]->get_shape().lens().size();
        if(contains(attributes, "axis"))
        {
            axis = static_cast<int>(attributes.at("axis").i());
        }
        if(axis < 0)
        {
            axis += num_dims;
        }

        return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
    }

    std::vector<instruction_ref> parse_split(const std::string&,
                                             const attribute_map& attributes,
                                             std::vector<instruction_ref> args) const
    {
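        // Split takes (axis, input); SplitV takes (input, split_sizes, axis)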
        bool vector_as_input = args.size() == 3;
        int num_outputs      = 1;
        auto axis_arg        = args[0];
        auto input_arg       = args[1];
        if(vector_as_input)
        {
            input_arg = args[0];
            axis_arg  = args[2];
        }

        if(contains(attributes, "num_split"))
            num_outputs = attributes.at("num_split").i();

        std::vector<int> splits(num_outputs);
        std::vector<int> slice_pos{0};
        if(vector_as_input)
        {
            splits      = args[1]->eval().get<int32_t>().to_vector();
            num_outputs = splits.size();
        }

        assert(num_outputs > 0);

        if(num_outputs == 1)
            return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};

        auto lens     = input_arg->get_shape().lens();
        auto num_dims = lens.size();
        int axis      = axis_arg->eval().at<int32_t>();

        // ensure split is made evenly if "num_split" is used
        assert(vector_as_input or lens[axis] % num_outputs == 0);

        auto split_size = lens[axis] / num_outputs;

        // push back first end point of slice
        if(vector_as_input)
        {
            slice_pos.push_back(splits[0]);
        }
        else
        {
            slice_pos.push_back(split_size);
        }

        // calculate remaining end points for each slice
        for(auto i = 1; i < num_outputs; i++)
        {
            if(vector_as_input)
            {
                splits[i] += splits[i - 1];
                slice_pos.push_back(splits[i]);
            }
            else
            {
                slice_pos.push_back((i + 1) * split_size);
            }
        }
        std::vector<instruction_ref> result;
        for(auto i = 0; i < num_outputs; i++)
        {
            op::slice op;
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
            op.starts = std::vector<int64_t>(num_dims, 0);
            op.ends   = std::vector<int64_t>(lens.begin(), lens.end());

            op.starts[axis] = slice_pos[i];
            op.ends[axis]   = slice_pos[i + 1];
            result.push_back(mm->add_instruction(op, input_arg));
        }
        return result;
    }

    instruction_ref parse_squeeze(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::squeeze op;
        auto input_dims = args[0]->get_shape().lens();
        auto axes       = attributes.at("squeeze_dims").list().i();
        copy(axes, std::back_inserter(op.axes));

        if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
        {
            for(size_t i = 0; i < input_dims.size(); i++)
            {
                if(input_dims.at(i) == 1)
                {
                    op.axes.push_back(i);
                }
            }
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

    instruction_ref parse_stridedslice(const std::string&,
                                       const attribute_map& attributes,
                                       std::vector<instruction_ref> args)
    {
        op::slice op;
        auto starts              = args[1]->eval().get<int32_t>().to_vector();
        auto ends                = args[2]->eval().get<int32_t>().to_vector();
        auto l0                  = args[0];
        size_t num_axes          = l0->get_shape().lens().size();
        std::vector<size_t> axes = l0->get_shape().lens();

        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
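        // begin_mask/end_mask bits select the full range for the corresponding axis;
        // shrink_axis_mask marks axes to squeeze away after slicing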
        uint32_t begin_mask       = 0;
        uint32_t end_mask         = 0;
        uint32_t shrink_axis_mask = 0;
        uint32_t bitwise_compare  = 1;
        std::vector<int64_t> squeeze_axes;

        if(contains(attributes, "begin_mask"))
            begin_mask = static_cast<uint32_t>(attributes.at("begin_mask").i());

        if(contains(attributes, "end_mask"))
            end_mask = static_cast<uint32_t>(attributes.at("end_mask").i());

        if(contains(attributes, "shrink_axis_mask"))
            shrink_axis_mask = static_cast<uint32_t>(attributes.at("shrink_axis_mask").i());

        std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
        std::vector<int64_t> end_axes   = get_axes_from_mask(num_axes, end_mask);

        for(size_t i = 0; i < num_axes; i++)
        {
            if(begin_axes.at(i) == 1)
            {
                op.starts.at(i) = 0;
            }
            if(end_axes.at(i) == 1)
            {
                op.ends.at(i) = axes.at(i);
            }
        }

        auto l1 = mm->add_instruction(op, l0);
        if(shrink_axis_mask == 0)
            return l1;

        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                squeeze_axes.push_back(i);
        }

        return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
    }

    instruction_ref parse_transpose(const std::string&,
                                    const attribute_map&,
                                    std::vector<instruction_ref> args) const
    {
        auto perm = args[1]->eval().get<int32_t>().to_vector();
        op::transpose op;
        op.dims = std::vector<int64_t>(perm.begin(), perm.end());

        return mm->add_instruction(op, args.front());
    }

    void parse_graph(const tensorflow::GraphDef& graph)
    {
        nodes = get_nodes(graph, input_nodes);
        for(auto&& input : input_nodes)
        {
            const std::string& name   = input.name();
            attribute_map input_attrs = get_attributes(input);
            shape::type_t shape_type  = parse_type(input_attrs.at("dtype").type());
            std::vector<size_t> dims  = parse_dims(input_attrs.at("shape").shape());
            if(is_nhwc and dims.size() >= 4)
            {
                reorder_data(dims);
            }
            std::transform(dims.begin(), dims.end(), dims.begin(), [&](auto dim) {
                return static_cast<int>(dim) <= 0 ? batch_size : dim;
            });
            shape s            = shape{shape_type, dims};
            instructions[name] = to_nhwc(mm->add_parameter(name, s));
        }
        for(auto&& p : nodes)
        {
            this->parse_node(p.first);
        }

        // A return instruction still needs to be added at the end of the program
    }

    void parse_node(const std::string& name)
    {
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            // assert ops ignored
            if(node.op() == "Assert" or contains(name, "Assert"))
                return;
            // noOps ignored
            if(node.op() == "NoOp" or contains(name, "NoOp"))
                return;
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
                // control dependencies (signified by ^ before the name) are ignored
                if(contains(input, "^"))
                    continue;
                if(nodes.count(input) > 0)
                {
                    std::string iname;
                    // input was from a node with multiple outputs
                    if(contains(input, ':'))
                    {
                        iname = input.substr(0, input.find(':'));
                    }
                    else
                    {
                        iname = get_name(nodes.at(input));
                    }
                    assert(name != iname);
                    this->parse_node(iname);
                    args.push_back(instructions.at(input));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }

            std::vector<instruction_ref> result;
            if(ops.count(node.op()) == 0)
            {
                result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
            }
            else
            {
                result = ops[node.op()](get_attributes(node), args);
            }

            assert(!result.empty());
            // First output has no ":" delimiter
            instructions[name] = result.front();
            for(size_t i = 1; i < result.size(); i++)
            {
                instructions[name + ":" + std::to_string(i)] = result.at(i);
            }
        }
    }

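    // Reads a serialized tensorflow::GraphDef from the stream and parses it into the
    // program; throws if the protobuf cannot be deserialized.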
    void parse_from(std::istream& is)
    {
        tensorflow::GraphDef graph;
        if(graph.ParseFromIstream(&is))
        {
            this->parse_graph(graph);
        }
        else
        {
            throw std::runtime_error("Failed reading tf file");
        }
    }

    static attribute_map get_attributes(const tensorflow::NodeDef& node)
    {
        attribute_map result;
        for(auto&& attr : node.attr())
        {
            result[attr.first] = attr.second;
        }
        return result;
    }

    static std::string get_name(const tensorflow::NodeDef& node) { return node.name(); }

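    // Builds a name -> NodeDef map for the graph and collects Placeholder nodes, which
    // later become the program's input parameters.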
    static node_map get_nodes(const tensorflow::GraphDef& graph,
                              std::vector<tensorflow::NodeDef>& input_nodes)
    {
        node_map result;
        for(auto&& node : graph.node())
        {
            auto node_name = get_name(node);
            // assume each node in graph has an associated name
            if(node_name.empty())
                MIGRAPHX_THROW("tf node with no name found");
            result[node_name] = node;
            if(node.op() == "Placeholder")
            {
                input_nodes.push_back(node);
            }
        }
        return result;
    }

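    // Maps a tensorflow::DataType to the corresponding migraphx shape type; unsupported
    // types simply fall through and yield the default-constructed type.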
    static shape::type_t parse_type(const tensorflow::DataType t)
    {
        shape::type_t shape_type{};
        switch(t)
        {
        case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
        case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
        case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
        case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
        case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
        case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
        case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
        case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
        case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
        case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;

        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // these types are not expected in a tf protobuf
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        return shape_type;
    }

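    // Converts a TensorProto into a migraphx literal. Raw bytes in tensor_content are
    // reinterpreted in place; otherwise the typed value fields (float_val, int_val,
    // half_val, ...) are read, broadcasting a lone value across the full shape.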
    static literal parse_tensor(const tensorflow::TensorProto& t)
    {
        std::vector<size_t> dims = parse_dims(t.tensor_shape());
        size_t shape_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>());
        if(!t.tensor_content().empty()) // has raw data
        {
            const std::string& s = t.tensor_content();
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return literal{{shape::float_type, dims}, s.data()};
            case tensorflow::DataType::DT_BOOL:
            case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
            case tensorflow::DataType::DT_UINT16:
            case tensorflow::DataType::DT_INT16:
                return literal{{shape::int16_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT32:
                return literal{{shape::int32_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT64:
                return literal{{shape::int64_type, dims}, s.data()};
            case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, s.data()};
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                throw std::runtime_error("Unsupported tensorflow data type");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.dtype())
        {
        case tensorflow::DataType::DT_FLOAT:
            return create_literal(
                shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
        case tensorflow::DataType::DT_INT8:
            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_UINT16:
            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT16:
            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT32:
            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT64:
            return create_literal(
                shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
        case tensorflow::DataType::DT_BOOL:
            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
        case tensorflow::DataType::DT_HALF:
        {
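            // half values are stored as 32-bit ints in the protobuf, so narrow them to
            // uint16 and reinterpret the raw bits as half-precision floats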
            std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
            std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
            std::vector<half> data_half;
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        case tensorflow::DataType::DT_DOUBLE:
            return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_UINT32:
        case tensorflow::DataType::DT_UINT64:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
            throw std::runtime_error("Unsupported tensorflow data type");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

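    // Copies a repeated protobuf field into a vector of shape_size elements.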
    template <class T>
    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
                                        const size_t& shape_size)
    {
        std::vector<T> data_vals(shape_size);
        // if only one value is stored, broadcast it to fill the whole tensor
        if(data.size() == 1)
        {
            std::fill(data_vals.begin(), data_vals.end(), data[0]);
        }
        else
            copy(data.begin(), data.end(), data_vals.begin());
        return data_vals;
    }

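    // Extracts the dimension sizes of a TensorShapeProto as a plain vector.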
    static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
    {
        std::vector<size_t> dims;
        auto input_dims = s.dim();
        std::transform(input_dims.begin(),
                       input_dims.end(),
                       std::back_inserter(dims),
                       [](const tensorflow::TensorShapeProto_Dim& dim) { return dim.size(); });
        return dims;
    }

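    // Wraps parsed data in a literal of the given type and dims.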
    template <class T>
    static literal
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
    {
        // values given explicitly in the protobuf that describe at most one element are treated as scalars
        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
            return literal{{shape_type}, data};
        return literal{{shape_type, dims}, data};
    }
};

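// Parses a frozen TensorFlow protobuf from disk into a migraphx program.
// Example usage (illustrative sketch only; the file name and option values here are
// hypothetical):
//     tf_options options;
//     options.is_nhwc    = true;
//     options.batch_size = 1;
//     auto prog = parse_tf("frozen_graph.pb", options);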
program parse_tf(const std::string& name, tf_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    tf_parser parser;
    parser.is_nhwc    = options.is_nhwc;
    parser.batch_size = options.batch_size;

#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    parser.to_nchw(std::prev(parser.mm->end()));
    return std::move(parser.prog);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx