tf.cpp 54 KB
Newer Older
Khalique's avatar
Khalique committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <graph.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include <functional>
#include <array>
#include <utility>
#include <vector>

#include <migraphx/fallthrough.hpp>
#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
Khalique's avatar
Khalique committed
20
#include <migraphx/pad_calc.hpp>
Khalique's avatar
Khalique committed
21
22
23
24
25
26
27

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct tf_parser
{
    using attribute_map = std::unordered_map<std::string, tensorflow::AttrValue>;
Paul's avatar
Paul committed
28
    using node_map      = std::map<std::string, tensorflow::NodeDef>;
kahmed10's avatar
kahmed10 committed
29
30
    using op_func =
        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
Khalique's avatar
Khalique committed
31

Khalique's avatar
Khalique committed
32
33
34
    node_map nodes;
    std::vector<tensorflow::NodeDef> input_nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
35
    program prog            = program();
36
    module* mm              = prog.get_main_module();
37
38
    bool is_nhwc            = true;
    unsigned int batch_size = 1;
Khalique's avatar
Khalique committed
39
40
41

    std::unordered_map<std::string, op_func> ops;

Paul's avatar
Paul committed
42
    // True only when the parser is in NHWC mode and `ins` has a 4-dimensional shape.
    bool should_transpose(instruction_ref ins) const
    {
        if(not is_nhwc)
            return false;
        return ins->get_shape().lens().size() == 4;
    }

47
    // Appends a transpose with permutation {0, 2, 3, 1} after `ins` when
    // should_transpose(ins) holds; otherwise returns `ins` untouched.
    instruction_ref to_nhwc(instruction_ref ins) const
    {
        return should_transpose(ins) ? mm->add_instruction(op::transpose{{0, 2, 3, 1}}, ins)
                                     : ins;
    }

54
    // Appends a transpose with permutation {0, 3, 1, 2} after `ins` when
    // should_transpose(ins) holds; otherwise returns `ins` untouched.
    instruction_ref to_nchw(instruction_ref ins) const
    {
        return should_transpose(ins) ? mm->add_instruction(op::transpose{{0, 3, 1, 2}}, ins)
                                     : ins;
    }

61
    // Appends a transpose with permutation {3, 2, 0, 1} after `ins` when
    // should_transpose(ins) holds; otherwise returns `ins` untouched.
    // Used by the convolution parsers on their weight argument.
    instruction_ref to_kcxy(instruction_ref ins) const
    {
        return should_transpose(ins) ? mm->add_instruction(op::transpose{{3, 2, 0, 1}}, ins)
                                     : ins;
    }

68
    // Returns `ins` itself when its shape reports standard(); otherwise appends an
    // op::contiguous instruction and returns that.
    instruction_ref make_contiguous(instruction_ref ins) const
    {
        const bool already_standard = ins->get_shape().standard();
        return already_standard ? ins : mm->add_instruction(op::contiguous{}, ins);
    }

    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
Paul's avatar
Paul committed
79
        std::transform(
Paul's avatar
Paul committed
80
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
Paul's avatar
Paul committed
81
82
83
        return result;
    }

kahmed10's avatar
kahmed10 committed
84
85
86
87
88
89
90
91
    std::vector<instruction_ref> to_nhwc(const std::vector<instruction_ref>& args)
    {
        std::vector<instruction_ref> result(args.size());
        std::transform(
            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nhwc(ins); });
        return result;
    }

Khalique's avatar
Khalique committed
92
    std::vector<size_t>
93
    parse_axes(const attribute_map& attributes, const std::string& s, const size_t num_dims) const
94
    {
95
96
97
        auto attrs = attributes.at(s).list().i();
        std::vector<size_t> axes;
        copy(attrs.begin(), attrs.end(), std::back_inserter(axes));
Khalique's avatar
Khalique committed
98
        if(is_nhwc)
99
        {
Khalique's avatar
Khalique committed
100
            std::transform(axes.begin(), axes.end(), axes.begin(), [&](size_t axis) {
Khalique's avatar
Khalique committed
101
                return parse_axis(axis, num_dims);
Khalique's avatar
Khalique committed
102
            });
103
104
105
106
        }
        return axes;
    }

Khalique's avatar
Khalique committed
107
    template <class T>
108
    std::vector<T> parse_axes(std::vector<T> axes, const size_t num_dims) const
Khalique's avatar
Khalique committed
109
110
111
    {
        if(is_nhwc)
        {
112
            std::vector<T> new_axes;
Khalique's avatar
Khalique committed
113
114
115
            std::transform(axes.begin(),
                           axes.end(),
                           std::back_inserter(new_axes),
Khalique's avatar
Khalique committed
116
                           [&](size_t axis) { return parse_axis(axis, num_dims); });
117
            return new_axes;
Khalique's avatar
Khalique committed
118
        }
119
        return axes;
Khalique's avatar
Khalique committed
120
121
    }

Khalique's avatar
Khalique committed
122
123
124
    // tf stores certain attributes such as strides, dilations, as a 4D input.
    // The first and last dims are equal to 1, and the relevant data is in dims 2 and 3.
    // This helper function reorders the data to store for the respective operator member variables.
125
    template <class T>
126
    void reorder_data(std::vector<T>& prev_data) const
127
128
    {
        std::vector<T> new_data(prev_data.size());
129
        for(size_t i = 0; i < new_data.size(); i++)
130
        {
Khalique's avatar
Khalique committed
131
            auto new_idx         = parse_axis(i, new_data.size());
132
            new_data.at(new_idx) = prev_data.at(i);
133
        }
134
135
136
137
        prev_data = new_data;
    }

    template <class T>
138
    T parse_axis(const T& dim, const size_t num_dims) const
139
    {
Khalique's avatar
Khalique committed
140
        T new_dim = dim;
Khalique's avatar
Khalique committed
141
        if(is_nhwc and num_dims >= 4)
142
143
144
        {
            switch(dim)
            {
Khalique's avatar
Khalique committed
145
146
147
148
149
            case 0: new_dim = 0; break;
            case 1: new_dim = 2; break;
            case 2: new_dim = 3; break;
            case 3: new_dim = 1; break;
            default: break;
150
151
            }
        }
Khalique's avatar
Khalique committed
152
        return new_dim;
153
154
    }

155
156
157
158
159
160
161
    // Returns the sequence 0, 1, ..., num_axes - 1.
    std::vector<int64_t> get_axes(size_t num_axes) const
    {
        std::vector<int64_t> axes;
        axes.reserve(num_axes);
        for(size_t i = 0; i < num_axes; ++i)
        {
            axes.push_back(static_cast<int64_t>(i));
        }
        return axes;
    }

Khalique's avatar
Khalique committed
162
    // Expands a bit mask into a per-axis 0/1 vector of length `num_axes`.
    // Bit i of `mask` (the LSB corresponds to axis 0) decides entry i: 1 if set, 0 otherwise.
    // The mask only carries 32 bits, so axes at index 32 and above are reported as 0
    // rather than shifting by an amount >= the type width (undefined behaviour).
    std::vector<int64_t> get_axes_from_mask(const size_t num_axes, const uint32_t mask)
    {
        constexpr size_t mask_bits = 32;

        std::vector<int64_t> axes;
        axes.reserve(num_axes);
        for(size_t i = 0; i < num_axes; i++)
        {
            const bool is_set = i < mask_bits and ((mask >> i) & 1u) != 0;
            axes.push_back(is_set ? 1 : 0);
        }
        return axes;
    }

Khalique's avatar
Khalique committed
177
178
    tf_parser()
    {
Khalique's avatar
Khalique committed
179
        add_generic_op("All", op::identity{});
Khalique's avatar
Khalique committed
180
        add_generic_op("Identity", op::identity{});
Khalique's avatar
Khalique committed
181
        add_generic_op("LessEqual", op::identity{});
Khalique's avatar
Khalique committed
182
        add_generic_op("Relu", op::relu{});
kahmed10's avatar
kahmed10 committed
183
        // add_generic_op("Relu6", op::clip{6.0, 0.0});
Khalique's avatar
Khalique committed
184
        add_generic_op("Rsqrt", op::rsqrt{});
Khalique's avatar
Khalique committed
185
        add_generic_op("Tanh", op::tanh{});
Khalique's avatar
Khalique committed
186
        add_generic_op("StopGradient", op::identity{});
Khalique's avatar
Khalique committed
187

188
        add_binary_op("Add", op::add{});
Khalique's avatar
Khalique committed
189
        add_binary_op("Mul", op::mul{});
Khalique's avatar
Khalique committed
190
        add_binary_op("Pow", op::pow{});
Khalique's avatar
Khalique committed
191
        add_binary_op("SquaredDifference", op::sqdiff{});
Khalique's avatar
Khalique committed
192
        add_binary_op("Sub", op::sub{});
Khalique's avatar
Khalique committed
193

194
195
        add_mem_op("ArgMax", &tf_parser::parse_arg_op<op::argmax>, false);
        add_mem_op("ArgMin", &tf_parser::parse_arg_op<op::argmin>, false);
196
        add_mem_op("AvgPool", &tf_parser::parse_pooling);
Khalique's avatar
Khalique committed
197
        add_mem_op("BatchMatMul", &tf_parser::parse_matmul, false);
Khalique's avatar
Khalique committed
198
        add_mem_op("BatchMatMulV2", &tf_parser::parse_matmul, false);
199
        add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
Khalique's avatar
Khalique committed
200
        add_mem_op("Cast", &tf_parser::parse_cast, false);
Paul's avatar
Paul committed
201
        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
Khalique's avatar
Khalique committed
202
        add_mem_op("Const", &tf_parser::parse_constant);
Paul's avatar
Paul committed
203
        add_mem_op("Conv2D", &tf_parser::parse_conv);
Paul's avatar
Paul committed
204
        add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
205
        add_mem_op("ExpandDims", &tf_parser::parse_expanddims, false);
Khalique's avatar
Khalique committed
206
        add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
Khalique's avatar
Khalique committed
207
        add_mem_op("GatherV2", &tf_parser::parse_gather, false);
Paul's avatar
Paul committed
208
        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
209
        add_mem_op("MaxPool", &tf_parser::parse_pooling);
Khalique's avatar
Khalique committed
210
        add_mem_op("Mean", &tf_parser::parse_mean, false);
Khalique's avatar
Khalique committed
211
        add_mem_op("OneHot", &tf_parser::parse_onehot, false);
Paul's avatar
Paul committed
212
        add_mem_op("Pack", &tf_parser::parse_pack, false);
Paul's avatar
Paul committed
213
        add_mem_op("Pad", &tf_parser::parse_pad);
kahmed10's avatar
kahmed10 committed
214
        add_mem_op("Relu6", &tf_parser::parse_relu6);
Paul's avatar
Paul committed
215
        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
216
        add_mem_op("Shape", &tf_parser::parse_shape, false);
Khalique's avatar
Khalique committed
217
        add_mem_op("Slice", &tf_parser::parse_slice, false);
kahmed10's avatar
kahmed10 committed
218
219
        add_mem_op("Split", &tf_parser::parse_split, false);
        add_mem_op("SplitV", &tf_parser::parse_split, false);
Khalique's avatar
Khalique committed
220
        add_mem_op("Softmax", &tf_parser::parse_softmax<op::softmax>, false);
Paul's avatar
Paul committed
221
        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
222
        add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
Khalique's avatar
Khalique committed
223
        add_mem_op("Transpose", &tf_parser::parse_transpose, false);
Khalique's avatar
Khalique committed
224
225
    }

226
    template <class F>
kahmed10's avatar
kahmed10 committed
227
    void add_op(const std::string& name, F f, bool transpose = true)
228
    {
Paul's avatar
Paul committed
229
        if(transpose)
Paul's avatar
Paul committed
230
        {
kahmed10's avatar
kahmed10 committed
231
232
233
234
235
236
            ops.emplace(
                name,
                op_func{
                    [=](const attribute_map& attributes, const std::vector<instruction_ref>& args) {
                        return std::vector<instruction_ref>{to_nhwc(f(attributes, to_nchw(args)))};
                    }});
Paul's avatar
Paul committed
237
238
239
        }
        else
        {
kahmed10's avatar
kahmed10 committed
240
241
242
243
244
            ops.emplace(name,
                        op_func{[=](const attribute_map& attributes,
                                    const std::vector<instruction_ref>& args) {
                            return std::vector<instruction_ref>{f(attributes, args)};
                        }});
Paul's avatar
Paul committed
245
        }
246
247
    }

Khalique's avatar
Khalique committed
248
    template <class F>
Paul's avatar
Paul committed
249
    void add_mem_op(std::string name, F f, bool transpose = true)
Khalique's avatar
Khalique committed
250
    {
Paul's avatar
Paul committed
251
252
253
254
255
        add_op(name,
               [=](auto&&... xs) {
                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
               },
               transpose);
Khalique's avatar
Khalique committed
256
257
258
259
260
    }

    template <class T>
    void add_binary_op(std::string name, T x)
    {
Paul's avatar
Paul committed
261
262
263
264
265
266
267
268
269
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                   if(args.size() != 2)
                       MIGRAPHX_THROW("binary operators should have 2 operands");
                   // TODO
                   // if(contains(attributes, "data_format"))
                   // {
                   //     if(is_nhwc)
                   //     {
270
                   //         l0 = mm->add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
Paul's avatar
Paul committed
271
272
273
274
275
                   //     }
                   // }
                   return add_broadcastable_binary_op(args[0], args[1], x);
               },
               false);
Khalique's avatar
Khalique committed
276
277
278
279
280
    }

    // Adds the binary operator `x` applied to arg0 and arg1.
    // If the two shapes already have identical lens the operator is applied directly;
    // otherwise both inputs are first multibroadcast to a common output shape.
    // In both cases the operands go through to_nchw and the result through to_nhwc.
    template <class T>
    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
    {
        const auto& lens0 = arg0->get_shape().lens();
        const auto& lens1 = arg1->get_shape().lens();

        if(lens0 == lens1)
        {
            return to_nhwc(mm->add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
        }

        // Example:
        // s0 = (3,2,4,5) and s1 = (2,1,1)
        //
        // In this case we need to broadcast (:,1,1) portion of
        // s1 plus broadcast the 1st dimension of s1
        // giving output_lens = (3,2,4,5)
        //
        // Another example:
        // s0 = (3,2,1,5) and s1 = (2,7,5)
        // In this case we need to broadcast the (:,:,1:,:) axis
        // of s0 plus the 1st dimension of s1 giving
        // output_lens = (3,2,7,5)

        // Pick the shape with fewer dimensions (the first one on a tie) as the short one.
        const bool first_is_short = lens0.size() <= lens1.size();
        const auto& short_lens    = first_is_short ? lens0 : lens1;
        const auto& long_lens     = first_is_short ? lens1 : lens0;

        // Start from the longer shape and take the element-wise max over the trailing
        // dimensions that both shapes share.
        std::vector<size_t> output_lens(long_lens);
        const auto offset = long_lens.size() - short_lens.size();
        std::transform(short_lens.begin(),
                       short_lens.end(),
                       long_lens.begin() + offset,
                       output_lens.begin() + offset,
                       [](auto a, auto b) { return std::max(a, b); });

        auto l0 = mm->add_instruction(op::multibroadcast{output_lens}, arg0);
        auto l1 = mm->add_instruction(op::multibroadcast{output_lens}, arg1);
        return to_nhwc(mm->add_instruction(x, to_nchw(l0), to_nchw(l1)));
    }

    template <class T>
Paul's avatar
Paul committed
323
    void add_generic_op(std::string name, T x, bool transpose = true)
Khalique's avatar
Khalique committed
324
    {
Paul's avatar
Paul committed
325
326
        add_op(name,
               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
327
                   return mm->add_instruction(x, args);
Paul's avatar
Paul committed
328
329
               },
               transpose);
Khalique's avatar
Khalique committed
330
331
    }

332
333
334
335
336
337
    // Shared handler for ArgMax / ArgMin (Op is op::argmax or op::argmin).
    // The axis is read by evaluating args[1]; Op{axis} is applied to the first argument
    // and the reduced axis is then removed with op::squeeze.
    template <class Op>
    instruction_ref
    parse_arg_op(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
        const int64_t axis = args[1]->eval().at<int64_t>();
        auto reduced       = mm->add_instruction(Op{axis}, args.front());
        return mm->add_instruction(op::squeeze{{axis}}, reduced);
    }

342
343
344
    // Handler for FusedBatchNorm: adds op::batch_norm_inference over all arguments.
    // epsilon comes from the "epsilon" attribute when present (default 1e-5f); momentum
    // is fixed at 0.9f and the mode at op::batch_norm_inference::spatial.
    instruction_ref parse_batchnorm(const std::string&,
                                    attribute_map attributes,
                                    std::vector<instruction_ref> args) const
    {
        const float momentum = 0.9f;
        const float epsilon =
            contains(attributes, "epsilon") ? attributes.at("epsilon").f() : 1e-5f;
        const op::batch_norm_inference::bn_infer_mode_t mode = op::batch_norm_inference::spatial;

        op::batch_norm_inference bn{epsilon, momentum, mode};
        return mm->add_instruction(bn, std::move(args));
    }

357
    // Handler for BiasAdd: broadcasts args[1] along axis 1 to the lens of args[0] and
    // adds the two. Axis 1 assumes the previous layer's output is in NCHW, i.e. the
    // broadcast is over the channel dimension.
    instruction_ref
    parse_biasadd(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const uint64_t channel_axis = 1;
        const auto& target_lens     = args[0]->get_shape().lens();
        auto bias = mm->add_instruction(op::broadcast{channel_axis, target_lens}, args[1]);
        return mm->add_instruction(op::add{}, args[0], bias);
    }

365
366
367
    // Handler for Cast: converts the arguments to the type named by the "DstT" attribute
    // (translated by parse_type) using op::convert. Throws if "DstT" is missing.
    instruction_ref parse_cast(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        const shape::type_t target_type = parse_type(attributes.at("DstT").type());
        op::convert convert_op{target_type};
        return mm->add_instruction(convert_op, std::move(args));
    }

373
374
375
    // Handler for ConcatV2.
    // The "N" attribute gives the number of tensors to concatenate; the argument at index
    // N is evaluated to obtain the concatenation axis. The first N arguments are then
    // passed to op::concat.
    // Throws if "N" is missing or if there is no argument at index N (previously this
    // indexed `args` out of bounds).
    instruction_ref parse_concat(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        // get index for axis within args
        const size_t axis_idx = attributes.at("N").i();
        if(axis_idx >= args.size())
        {
            MIGRAPHX_THROW("concat: missing axis input at index N");
        }
        const int64_t axis = args[axis_idx]->eval().at<int64_t>();
        op::concat op{axis};
        // pass only the first N arguments; the argument at index N is the axis value
        return mm->add_instruction(
            op, std::vector<instruction_ref>(args.begin(), args.begin() + axis_idx));
    }

    // Handler for Const: converts the tensor stored in the "value" attribute into a
    // literal (via parse_tensor) and adds it to the module. The arguments are ignored.
    // Throws if "value" is missing.
    instruction_ref parse_constant(const std::string&,
                                   attribute_map attributes,
                                   const std::vector<instruction_ref>&) const
    {
        const auto& tensor_proto = attributes.at("value").tensor();
        literal constant_value   = parse_tensor(tensor_proto);
        return mm->add_literal(constant_value);
    }

394
395
396
    // Handler for Conv2D: builds an op::convolution from the node's attributes.
    //
    // args[0] is the input and args[1] the weights; the weights go through to_kcxy once
    // and that same instruction is used both for the SAME-padding computation and as the
    // convolution operand (the previous code added a second, duplicate transpose).
    //
    // Attributes read when present:
    //   "strides", "dilations"  - must hold 4 values after reorder_data; elements 2 and 3
    //                             are used for the two spatial dimensions.
    //   "padding"               - "SAME", "VALID" or "EXPLICIT" (matched by substring).
    //   "explicit_paddings"     - read only for EXPLICIT padding.
    // Throws when a list has the wrong length or explicit padding is asymmetric.
    instruction_ref parse_conv(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        op::convolution op;
        if(contains(attributes, "strides"))
        {
            std::vector<size_t> stride;
            copy(attributes.at("strides").list().i(), std::back_inserter(stride));
            reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilation;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilation));
            reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto weights = to_kcxy(args[1]);
        auto l0      = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& pad_mode = attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode                 = op::padding_mode_t::same;
                std::vector<size_t> weight_dims = weights->get_shape().lens();
                size_t weight_h                 = weight_dims[2];
                size_t weight_w                 = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    // Asymmetric padding cannot be expressed on the convolution itself,
                    // so pad the input explicitly instead.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = mm->add_instruction(migraphx::op::pad{padding}, l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
            else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
                // NOTE(review): TensorFlow documents explicit_paddings as a (before, after)
                // pair per dimension, which would be 8 values for a 4-D input; this code
                // accepts exactly 4 - confirm against real graphs.
                std::vector<size_t> padding;
                copy(attributes.at("explicit_paddings").list().i(), std::back_inserter(padding));
                if(padding.size() != 4)
                {
                    MIGRAPHX_THROW("padding should have 4 values");
                }
                if(padding[0] != padding[2] || padding[1] != padding[3])
                {
                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
                }
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        // Reuse the transposed weights created above instead of transposing args[1] again.
        return mm->add_instruction(op, {l0, weights});
    }

Khalique's avatar
Khalique committed
475
476
    // Handler for DepthwiseConv2dNative: builds a grouped op::convolution whose group
    // count equals dimension 1 of the input (args[0]).
    //
    // "strides" / "dilations" (when present) must hold 4 values after reorder_data;
    // elements 2 and 3 are used. "padding" may be "SAME" or "VALID" (substring match).
    // The weights (args[1]) go through to_kcxy and are reshaped before the convolution.
    instruction_ref parse_depthwiseconv(const std::string&,
                                        attribute_map attributes,
                                        std::vector<instruction_ref> args) const
    {
        op::convolution conv;
        size_t in_channels = args[0]->get_shape().lens()[1];
        conv.group         = in_channels;

        if(contains(attributes, "strides"))
        {
            std::vector<size_t> strides;
            copy(attributes.at("strides").list().i(), std::back_inserter(strides));
            reorder_data(strides);
            if(strides.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            conv.stride[0] = strides[2];
            conv.stride[1] = strides[3];
        }

        auto filter = to_kcxy(args[1]);

        if(contains(attributes, "dilations"))
        {
            std::vector<size_t> dilations;
            copy(attributes.at("dilations").list().i(), std::back_inserter(dilations));
            reorder_data(dilations);
            if(dilations.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            conv.dilation[0] = dilations[2];
            conv.dilation[1] = dilations[3];
        }

        auto input = args[0];
        if(contains(attributes, "padding"))
        {
            const std::string& mode = attributes.at("padding").s();

            if(mode.find("SAME") != std::string::npos)
            {
                conv.padding_mode               = op::padding_mode_t::same;
                std::vector<size_t> filter_lens = filter->get_shape().lens();
                size_t kernel_h                 = filter_lens[2];
                size_t kernel_w                 = filter_lens[3];

                auto input_lens = input->get_shape().lens();
                std::vector<int64_t> pads(input_lens.size());
                calculate_padding(
                    0, pads, input_lens[2], conv.stride[0], conv.dilation[0], kernel_h);
                calculate_padding(
                    1, pads, input_lens[3], conv.stride[1], conv.dilation[1], kernel_w);

                const bool symmetric = pads[0] == pads[2] && pads[1] == pads[3];
                if(symmetric)
                {
                    conv.padding[0] = pads[0];
                    conv.padding[1] = pads[1];
                }
                else
                {
                    // Asymmetric padding is applied to the input with an explicit pad op.
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    input = mm->add_instruction(migraphx::op::pad{padding}, input);
                }
            }
            else if(mode.find("VALID") != std::string::npos)
            {
                conv.padding_mode = op::padding_mode_t::valid;
            }
        }

        std::vector<int64_t> reshaped_lens;
        copy(filter->get_shape().lens(), std::back_inserter(reshaped_lens));

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = reshaped_lens[0];
        int64_t out_channels = in_channels * multiplier;
        reshaped_lens[0]     = out_channels;
        reshaped_lens[1]     = 1;

        // Make sure weights are contiguous before doing reshape
        auto contiguous_filter = make_contiguous(filter);
        auto reshaped_filter =
            mm->add_instruction(op::reshape{reshaped_lens}, contiguous_filter);

        return mm->add_instruction(conv, {input, reshaped_filter});
    }

561
562
563
    // Handler for ExpandDims: inserts a dimension of size 1 into the shape of args[0]
    // and emits the result as an op::reshape.
    //
    // The position is read by evaluating args[1] as int32. A non-negative value is the
    // insertion index; a negative value counts from the end, so -1 appends after the
    // last dimension. Valid positions are [-(rank + 1), rank]; anything else throws
    // (previously an out-of-range value produced an invalid iterator).
    instruction_ref parse_expanddims(const std::string&,
                                     const attribute_map&,
                                     std::vector<instruction_ref> args) const
    {
        const std::vector<size_t>& input_dims = args[0]->get_shape().lens();
        std::vector<int64_t> new_dims(input_dims.begin(), input_dims.end());
        const auto num_dims = static_cast<int64_t>(input_dims.size());
        const int64_t dim   = args[1]->eval().at<int32_t>();

        // Work in signed arithmetic so negative dims do not wrap around.
        const int64_t insert_pos = (dim < 0) ? (num_dims + dim + 1) : dim;
        if(insert_pos < 0 or insert_pos > num_dims)
        {
            MIGRAPHX_THROW("expand_dims: dim is out of range for the input rank");
        }

        new_dims.insert(new_dims.begin() + insert_pos, 1);
        return mm->add_instruction(op::reshape{new_dims}, args[0]);
    }

Khalique's avatar
Khalique committed
581
    // Handler for GatherV2: the axis is read by evaluating args[2] as int32, and
    // op::gather is applied to args[0] and args[1].
    instruction_ref
    parse_gather(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        int gather_axis = args[2]->eval().at<int32_t>();
        op::gather gather_op{gather_axis};
        return mm->add_instruction(gather_op, {args[0], args[1]});
    }

589
590
591
    // Handler for MatMul / BatchMatMul / BatchMatMulV2: emits op::dot on the two inputs,
    // transposing the last two dimensions of either input first when requested.
    //
    // Attributes read when present: "transpose_a" / "transpose_b", then "adj_x" / "adj_y"
    // (the latter pair overrides the former if both exist).
    //
    // The permutation is built per input, from that input's own rank, and only when a
    // transpose is actually requested. Requesting a transpose of an input with fewer
    // than two dimensions throws (previously the permutation was always built from
    // args[0] and underflowed for rank < 2).
    instruction_ref parse_matmul(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        bool transa = false;
        bool transb = false;

        if(contains(attributes, "transpose_a"))
        {
            transa = attributes.at("transpose_a").b();
        }
        if(contains(attributes, "transpose_b"))
        {
            transb = attributes.at("transpose_b").b();
        }

        if(contains(attributes, "adj_x"))
        {
            transa = attributes.at("adj_x").b();
        }
        if(contains(attributes, "adj_y"))
        {
            transb = attributes.at("adj_y").b();
        }

        // Adds a transpose that swaps the two innermost dimensions of `arg`.
        auto transpose_last_two = [&](instruction_ref arg) {
            const auto rank = arg->get_shape().lens().size();
            if(rank < 2)
            {
                MIGRAPHX_THROW("matmul: cannot transpose an input with fewer than 2 dimensions");
            }
            std::vector<int64_t> perm(rank);
            std::iota(perm.begin(), perm.end(), int64_t{0});
            // swap the last two elements
            std::iter_swap(perm.end() - 1, perm.end() - 2);
            return mm->add_instruction(op::transpose{perm}, arg);
        };

        auto l1 = (transa) ? transpose_last_two(args[0]) : args[0];
        auto l2 = (transb) ? transpose_last_two(args[1]) : args[1];

        return mm->add_instruction(op::dot{}, l1, l2);
    }

625
626
627
    // Lowers a mean node to reduce_mean, followed by squeeze when dimensions are
    // not kept.
    //   attributes : must contain the boolean "keep_dims" (attribute_map::at throws
    //                when the key is missing)
    //   args[0]    : tensor to reduce
    //   args[1]    : reduction axes, evaluated while parsing (read as int32 and
    //                widened to int64)
    instruction_ref parse_mean(const std::string&,
                               attribute_map attributes,
                               std::vector<instruction_ref> args) const
    {
        const bool keep_dims   = attributes.at("keep_dims").b();
        const auto reduce_axes = args[1]->eval().get<int32_t>().to_vector<int64_t>();

        auto reduced = mm->add_instruction(op::reduce_mean{reduce_axes}, args[0]);
        if(keep_dims)
            return reduced;

        // drop the axes that were just reduced
        return mm->add_instruction(op::squeeze{reduce_axes}, reduced);
    }

643
644
645
    // Lowers a one-hot node to a gather over a depth x depth literal table.
    //   args[0] : indices, used as the gather indices
    //   args[1] : depth, evaluated while parsing as a single int32
    //   args[2] : on value  (float) placed on the diagonal of the table
    //   args[3] : off value (float) used everywhere else
    // Only axis == -1 (also the value used when the "axis" attribute is absent) is
    // supported; any other axis throws. A negative depth is rejected explicitly
    // rather than being converted to a huge unsigned size.
    instruction_ref parse_onehot(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args) const
    {
        const int32_t depth_arg = args[1]->eval().at<int32_t>();
        const float on_value    = args[2]->eval().at<float>();
        const float off_value   = args[3]->eval().at<float>();

        int64_t axis = -1;
        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();
        // checked before the table is built so unsupported nodes allocate nothing
        if(axis != -1)
            MIGRAPHX_THROW("MIGraphX does not support axis != -1");

        if(depth_arg < 0)
            MIGRAPHX_THROW("TF_PARSER: one-hot depth must not be negative");
        const auto depth = static_cast<size_t>(depth_arg);

        // row i of the table is the one-hot encoding of index i
        std::vector<float> depth_input(depth * depth, off_value);
        for(size_t i = 0; i < depth; i++)
        {
            depth_input[depth * i + i] = on_value;
        }

        shape s{shape::float_type, {depth, depth}};
        auto l0 = mm->add_literal({s, depth_input});
        return mm->add_instruction(op::gather{0}, {l0, args[0]});
    }

Khalique's avatar
Khalique committed
670
671
    // Lowers a pack (stack) node: every input is unsqueezed at `axis` and the
    // results are concatenated along that same axis; the concat output is then
    // passed through to_nhwc.
    //
    // The optional "axis" attribute defaults to 0. With R being the rank of the
    // first input, the accepted range is [-(R + 1), R]; negative values are wrapped
    // by adding R + 1. Values outside that range throw. The comparison is done in
    // signed arithmetic so that negative axes are not misread as large unsigned
    // numbers.
    instruction_ref parse_pack(const std::string&,
                               const attribute_map& attributes,
                               std::vector<instruction_ref> args) const
    {
        // reinterpret as unsqueeze with concat
        int64_t axis = 0;
        if(contains(attributes, "axis"))
            axis = attributes.at("axis").i();

        const size_t input_size = args.front()->get_shape().lens().size();
        const auto rank         = static_cast<int64_t>(input_size);
        if(axis > rank or axis < -(rank + 1))
        {
            MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                           " must be smaller than input size " + to_string(input_size));
        }
        // the output has rank + 1 dimensions, hence the wrap-around offset
        if(axis < 0)
            axis += rank + 1;

        std::vector<instruction_ref> unsqueezed_args;
        unsqueezed_args.reserve(args.size());
        std::transform(
            args.begin(),
            args.end(),
            std::back_inserter(unsqueezed_args),
            [&](instruction_ref arg) { return mm->add_instruction(op::unsqueeze{{axis}}, arg); });
        return to_nhwc(mm->add_instruction(op::concat{axis}, unsqueezed_args));
    }

Khalique's avatar
Khalique committed
694
    // Lowers a pad node to op::pad.
    //   args[0] : tensor to pad
    //   args[1] : paddings, evaluated while parsing as a flat int32 list
    // Throws when the number of padding values is not exactly two per input
    // dimension, instead of reading past the end of the evaluated list.
    instruction_ref
    parse_pad(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const size_t ndims = args.front()->get_shape().lens().size();

        // in tf, the paddings are arranged as a 2d shape (ndims, 2),
        // the last dim contains the left padding and right padding respectively
        const auto tf_padding = args[1]->eval().get<int32_t>().to_vector();
        if(tf_padding.size() != 2 * ndims)
        {
            MIGRAPHX_THROW("TF_PARSER: pad expects " + std::to_string(2 * ndims) +
                           " padding values but got " + std::to_string(tf_padding.size()));
        }

        std::vector<std::pair<int32_t, int32_t>> pad_per_dim(ndims);
        for(size_t i = 0; i < ndims; i++)
        {
            pad_per_dim[i].first  = tf_padding[2 * i];
            pad_per_dim[i].second = tf_padding[2 * i + 1];
        }
        // the (left, right) pairs are reordered together, one pair per dimension
        reorder_data(pad_per_dim);

        // op::pad receives all left paddings first, followed by all right paddings
        std::vector<int64_t> pads(ndims * 2);
        for(size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;
            pads[i + ndims] = pad_per_dim[i].second;
        }

        op::pad op;
        op.pads = pads;
        return mm->add_instruction(op, args.front());
    }

721
722
    // Lowers a pooling node to op::pooling.
    //   name       : node op name; names starting with "Max" select "max" mode,
    //                everything else selects "average"
    //   attributes : optional "strides" and "ksize" (integer lists that must hold
    //                4 values after reorder_data) and optional "padding" (string)
    //   args[0]    : tensor to pool
    // Entries 2 and 3 of the reordered lists are used for the two pooled
    // dimensions. When the padding string contains "SAME" the padding is computed
    // with calculate_padding: equal begin/end padding is stored on the operator,
    // unequal padding is applied through a separate pad instruction.
    instruction_ref parse_pooling(const std::string& name,
                                  attribute_map attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::pooling op{starts_with(name, "Max") ? "max" : "average"};

        // copies an integer list attribute into a vector and runs reorder_data on it
        auto read_list = [&](const std::string& key) {
            std::vector<size_t> values;
            copy(attributes.at(key).list().i(), std::back_inserter(values));
            reorder_data(values);
            return values;
        };

        if(contains(attributes, "strides"))
        {
            const auto stride = read_list("strides");
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(attributes, "ksize"))
        {
            const auto ksize = read_list("ksize");
            if(ksize.size() != 4)
            {
                MIGRAPHX_THROW("ksize should have 4 values");
            }
            op.lengths[0] = ksize[2];
            op.lengths[1] = ksize[3];
        }

        auto l0 = args[0];
        const bool same_padding =
            contains(attributes, "padding") and
            attributes.at("padding").s().find("SAME") != std::string::npos;
        if(same_padding)
        {
            auto input_dims = l0->get_shape().lens();
            std::vector<int64_t> pads(input_dims.size());
            calculate_padding(0, pads, input_dims[2], op.stride[0], 1, op.lengths[0]);
            calculate_padding(1, pads, input_dims[3], op.stride[1], 1, op.lengths[1]);

            const bool symmetric = pads[0] == pads[2] and pads[1] == pads[3];
            if(symmetric)
            {
                op.padding[0] = pads[0];
                op.padding[1] = pads[1];
            }
            else
            {
                // NOTE(review): the fill value is the lowest float for both pooling
                // modes - confirm this is what is wanted for "average"
                std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                l0                           = mm->add_instruction(
                    migraphx::op::pad{padding, std::numeric_limits<float>::lowest()}, l0);
            }
        }
        return mm->add_instruction(op, l0);
    }
Khalique's avatar
Khalique committed
778

kahmed10's avatar
kahmed10 committed
779
    // Lowers a relu6 node to a clip of args[0] between 0 and 6.
    // The two scalar bounds are added as literals and multibroadcast to the
    // dimensions of the input before being handed to op::clip.
    instruction_ref
    parse_relu6(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const auto dims = args[0]->get_shape().lens();

        auto lower = mm->add_literal(0.0f);
        auto upper = mm->add_literal(6.0f);

        lower = mm->add_instruction(op::multibroadcast{dims}, lower);
        upper = mm->add_instruction(op::multibroadcast{dims}, upper);

        return mm->add_instruction(op::clip{}, args.front(), lower, upper);
    }

791
    // Lowers a reshape node to op::reshape.
    //   args[0] : tensor to reshape (passed through make_contiguous first)
    //   args[1] : target dimensions, evaluated while parsing
    // Throws unless exactly two arguments are supplied.
    instruction_ref
    parse_reshape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        if(args.size() != 2)
            MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");

        op::reshape op;
        auto new_shape = args[1]->eval();
        // visit() hands over the evaluated data in its own element type
        new_shape.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

802
803
804
    // The shape operator is replaced by a literal: the dimensions of args[0] are
    // known while parsing, so they are emitted directly as a 1-d int32 literal
    // with one element per dimension.
    instruction_ref
    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const std::vector<std::size_t> in_lens = args[0]->get_shape().lens();
        // each extent is converted from size_t to the int32 element type of the literal
        std::vector<int32_t> as_int32(in_lens.begin(), in_lens.end());
        migraphx::shape literal_shape(migraphx::shape::int32_type, {in_lens.size()});
        return mm->add_literal(migraphx::literal{literal_shape, as_int32});
    }

815
    // Lowers a slice node to op::slice over every axis of the input.
    //   args[0] : tensor to slice (passed through make_contiguous first)
    //   args[1] : start index per dimension, evaluated while parsing (int32)
    //   args[2] : extent per dimension, evaluated while parsing (int32); -1 means
    //             "up to the end of that dimension"
    // Throws when either list does not hold exactly one value per input dimension,
    // instead of indexing past the end of the evaluated lists.
    instruction_ref
    parse_slice(const std::string&, const attribute_map&, std::vector<instruction_ref> args) const
    {
        const auto starts     = args[1]->eval().get<int32_t>().to_vector();
        const auto size       = args[2]->eval().get<int32_t>().to_vector();
        const auto dims       = args[0]->get_shape().lens();
        const size_t num_axes = dims.size();

        if(starts.size() != num_axes or size.size() != num_axes)
        {
            MIGRAPHX_THROW("TF_PARSER: slice needs one begin and one size value per dimension");
        }

        op::slice op;
        op.starts = std::vector<int64_t>(starts.begin(), starts.end());
        op.ends   = std::vector<int64_t>(num_axes);
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);
        for(size_t i = 0; i < num_axes; i++)
        {
            if(size[i] == -1)
                op.ends[i] = dims[i];
            else
                // widen before adding so the sum is not computed in 32 bits
                op.ends[i] = static_cast<int64_t>(starts[i]) + size[i];
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

Khalique's avatar
Khalique committed
838
839
840
841
842
    // Shared lowering for softmax-like operators (templated so logsoftmax can
    // reuse it). Op is constructed from the resolved axis and applied to
    // make_contiguous(args[0]).
    // The optional "axis" attribute defaults to -1; a negative axis is shifted by
    // the number of dimensions of args[0].
    template <class Op>
    instruction_ref parse_softmax(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args)
    {
        const auto rank = args[0]->get_shape().lens().size();

        int axis = -1;
        if(contains(attributes, "axis"))
            axis = static_cast<int>(attributes.at("axis").i());
        if(axis < 0)
            axis += rank;

        return mm->add_instruction(Op{axis}, make_contiguous(args[0]));
    }

kahmed10's avatar
kahmed10 committed
858
859
    std::vector<instruction_ref> parse_split(const std::string&,
                                             const attribute_map& attributes,
860
                                             std::vector<instruction_ref> args) const
kahmed10's avatar
kahmed10 committed
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
    {
        bool vector_as_input = args.size() == 3;
        int num_outputs      = 1;
        auto axis_arg        = args[0];
        auto input_arg       = args[1];
        if(vector_as_input)
        {
            input_arg = args[0];
            axis_arg  = args[2];
        }

        if(contains(attributes, "num_split"))
            num_outputs = attributes.at("num_split").i();

        std::vector<int> splits(num_outputs);
        std::vector<int> slice_pos{0};
        if(vector_as_input)
        {
            splits      = args[1]->eval().get<int32_t>().to_vector();
            num_outputs = splits.size();
        }

        assert(num_outputs > 0);

        if(num_outputs == 1)
886
            return std::vector<instruction_ref>{mm->add_instruction(op::identity{}, input_arg)};
kahmed10's avatar
kahmed10 committed
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930

        auto lens     = input_arg->get_shape().lens();
        auto num_dims = lens.size();
        int axis      = axis_arg->eval().at<int32_t>();

        // ensure split is made evenly if "num_split" is used
        assert(vector_as_input or lens[axis] % num_outputs == 0);

        auto split_size = lens[axis] / num_outputs;

        // push back first end point of slice
        if(vector_as_input)
        {
            slice_pos.push_back(splits[0]);
        }
        else
        {
            slice_pos.push_back(split_size);
        }

        // calculate remaining end points for each slice
        for(auto i = 1; i < num_outputs; i++)
        {
            if(vector_as_input)
            {
                splits[i] += splits[i - 1];
                slice_pos.push_back(splits[i]);
            }
            else
            {
                slice_pos.push_back((i + 1) * split_size);
            }
        }
        std::vector<instruction_ref> result;
        for(auto i = 0; i < num_outputs; i++)
        {
            op::slice op;
            op.axes = std::vector<int64_t>(num_dims);
            std::iota(op.axes.begin(), op.axes.end(), 0);
            op.starts = std::vector<int64_t>(num_dims, 0);
            op.ends   = std::vector<int64_t>(lens.begin(), lens.end());

            op.starts[axis] = slice_pos[i];
            op.ends[axis]   = slice_pos[i + 1];
931
            result.push_back(mm->add_instruction(op, input_arg));
kahmed10's avatar
kahmed10 committed
932
933
934
935
        }
        return result;
    }

Khalique's avatar
Khalique committed
936
937
    // Lowers a squeeze node to op::squeeze applied to make_contiguous(args[0]).
    // The axes come from the "squeeze_dims" list attribute. When that attribute is
    // missing or empty, every dimension of the input whose extent is 1 is squeezed.
    // A missing attribute is treated like an empty list instead of letting
    // attribute_map::at throw.
    instruction_ref parse_squeeze(const std::string&,
                                  const attribute_map& attributes,
                                  std::vector<instruction_ref> args) const
    {
        op::squeeze op;
        if(contains(attributes, "squeeze_dims"))
        {
            copy(attributes.at("squeeze_dims").list().i(), std::back_inserter(op.axes));
        }

        if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
        {
            const auto input_dims = args[0]->get_shape().lens();
            for(size_t i = 0; i < input_dims.size(); i++)
            {
                if(input_dims.at(i) == 1)
                {
                    op.axes.push_back(i);
                }
            }
        }
        return mm->add_instruction(op, make_contiguous(args[0]));
    }

Khalique's avatar
Khalique committed
958
959
960
    // Lowers a strided-slice node to op::slice, followed by op::squeeze when a
    // shrink mask is set.
    //   args[0] : tensor to slice
    //   args[1] : start indices, evaluated while parsing (int32)
    //   args[2] : end indices, evaluated while parsing (int32)
    // Optional integer attributes "begin_mask", "end_mask" and "shrink_axis_mask"
    // default to 0. For every axis flagged by get_axes_from_mask the start is reset
    // to 0 (begin mask) or the end is set to the full extent (end mask). Axes whose
    // bit is set in the shrink mask are squeezed from the slice result.
    instruction_ref parse_stridedslice(const std::string&,
                                       const attribute_map& attributes,
                                       std::vector<instruction_ref> args)
    {
        const auto begin_vals = args[1]->eval().get<int32_t>().to_vector();
        const auto end_vals   = args[2]->eval().get<int32_t>().to_vector();
        auto l0               = args[0];

        const std::vector<size_t> in_lens = l0->get_shape().lens();
        const size_t num_axes             = in_lens.size();

        op::slice op;
        op.starts = std::vector<int64_t>(begin_vals.begin(), begin_vals.end());
        op.ends   = std::vector<int64_t>(end_vals.begin(), end_vals.end());
        op.axes   = std::vector<int64_t>(num_axes);
        std::iota(op.axes.begin(), op.axes.end(), 0);

        // an absent mask attribute is treated as 0
        auto read_mask = [&](const std::string& key) -> uint32_t {
            if(contains(attributes, key))
                return static_cast<uint32_t>(attributes.at(key).i());
            return 0;
        };
        const uint32_t begin_mask       = read_mask("begin_mask");
        const uint32_t end_mask         = read_mask("end_mask");
        const uint32_t shrink_axis_mask = read_mask("shrink_axis_mask");

        const std::vector<int64_t> begin_axes = get_axes_from_mask(num_axes, begin_mask);
        const std::vector<int64_t> end_axes   = get_axes_from_mask(num_axes, end_mask);

        for(size_t i = 0; i < num_axes; i++)
        {
            if(begin_axes.at(i) == 1)
                op.starts.at(i) = 0;
            if(end_axes.at(i) == 1)
                op.ends.at(i) = in_lens.at(i);
        }

        auto l1 = mm->add_instruction(op, l0);
        if(shrink_axis_mask == 0)
            return l1;

        std::vector<int64_t> squeeze_axes;
        const uint32_t lowest_bit = 1;
        for(size_t i = 0; i < num_axes; i++)
        {
            // the LSB corresponds to axis 0 when determining which axes to squeeze
            const bool shrink = ((shrink_axis_mask >> i) & lowest_bit) == 1;
            if(shrink)
                squeeze_axes.push_back(i);
        }

        return mm->add_instruction(op::squeeze{squeeze_axes}, l1);
    }

1017
1018
1019
    // Lowers a transpose node to op::transpose.
    //   args[0] : tensor to transpose
    //   args[1] : permutation, evaluated while parsing (int32, widened to int64)
    instruction_ref parse_transpose(const std::string&,
                                    const attribute_map&,
                                    std::vector<instruction_ref> args) const
    {
        const auto order = args[1]->eval().get<int32_t>().to_vector();

        op::transpose op;
        op.dims = std::vector<int64_t>(order.begin(), order.end());
        return mm->add_instruction(op, args.front());
    }

Khalique's avatar
Khalique committed
1028
1029
1030
1031
1032
    // Builds the program from a GraphDef.
    // First every placeholder collected by get_nodes becomes a program parameter
    // (its "dtype" and "shape" attributes give the parameter shape), then every
    // node of the graph is parsed by name.
    void parse_graph(const tensorflow::GraphDef& graph)
    {
        nodes = get_nodes(graph, input_nodes);

        for(auto&& placeholder : input_nodes)
        {
            const std::string& name  = placeholder.name();
            attribute_map attrs      = get_attributes(placeholder);
            shape::type_t elem_type  = parse_type(attrs.at("dtype").type());
            std::vector<size_t> dims = parse_dims(attrs.at("shape").shape());

            if(is_nhwc and dims.size() >= 4)
            {
                reorder_data(dims);
            }

            // dimensions that are not positive when read as int are replaced by batch_size
            for(auto& dim : dims)
            {
                if(static_cast<int>(dim) <= 0)
                    dim = batch_size;
            }

            instructions[name] = to_nhwc(mm->add_parameter(name, shape{elem_type, dims}));
        }

        for(auto&& entry : nodes)
        {
            this->parse_node(entry.first);
        }

        // Needs to add a ret instruction at the end of
        // the program
    }

    void parse_node(const std::string& name)
    {
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
Khalique's avatar
Khalique committed
1061
1062
1063
            // assert ops ignored
            if(node.op() == "Assert" or contains(name, "Assert"))
                return;
Khalique's avatar
Khalique committed
1064
1065
1066
1067
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
Khalique's avatar
Khalique committed
1068
1069
1070
                // control dependencies (signified by ^ before the name) are ignored
                if(contains(input, "^"))
                    continue;
Khalique's avatar
Khalique committed
1071
1072
                if(nodes.count(input) > 0)
                {
kahmed10's avatar
kahmed10 committed
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
                    std::string iname;
                    // input was from a node with multiple outputs
                    if(contains(input, ':'))
                    {
                        iname = input.substr(0, input.find(':'));
                    }
                    else
                    {
                        iname = get_name(nodes.at(input));
                    }
Khalique's avatar
Khalique committed
1083
1084
                    assert(name != iname);
                    this->parse_node(iname);
kahmed10's avatar
kahmed10 committed
1085
                    args.push_back(instructions.at(input));
Khalique's avatar
Khalique committed
1086
1087
1088
1089
1090
1091
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }
kahmed10's avatar
kahmed10 committed
1092
1093

            std::vector<instruction_ref> result;
Khalique's avatar
Khalique committed
1094
1095
            if(ops.count(node.op()) == 0)
            {
1096
                result.push_back(mm->add_instruction(op::unknown{node.op()}, args));
Khalique's avatar
Khalique committed
1097
1098
1099
            }
            else
            {
kahmed10's avatar
kahmed10 committed
1100
1101
1102
1103
1104
1105
1106
1107
1108
                result = ops[node.op()](get_attributes(node), args);
            }

            assert(!result.empty());
            // First output has no ":" delimiter
            instructions[name] = result.front();
            for(size_t i = 1; i < result.size(); i++)
            {
                instructions[name + ":" + std::to_string(i)] = result.at(i);
Khalique's avatar
Khalique committed
1109
1110
1111
1112
            }
        }
    }

1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
    // Reads a serialized GraphDef from `is` and parses it into the program.
    // Throws std::runtime_error when the stream cannot be decoded.
    void parse_from(std::istream& is)
    {
        tensorflow::GraphDef graph;
        if(not graph.ParseFromIstream(&is))
        {
            throw std::runtime_error("Failed reading tf file");
        }
        this->parse_graph(graph);
    }

Khalique's avatar
Khalique committed
1126
1127
1128
    // Copies the attributes of `node` into an attribute_map keyed by attribute name.
    static attribute_map get_attributes(const tensorflow::NodeDef& node)
    {
        attribute_map collected;
        const auto& node_attrs = node.attr();
        for(const auto& entry : node_attrs)
        {
            collected[entry.first] = entry.second;
        }
        return collected;
    }

Khalique's avatar
Khalique committed
1136
    // Returns a copy of the name stored in `node`.
    static std::string get_name(const tensorflow::NodeDef& node)
    {
        return node.name();
    }
Khalique's avatar
Khalique committed
1137

Khalique's avatar
Khalique committed
1138
1139
    // Indexes every node of `graph` by its name and returns that map.
    // Nodes whose op is "Placeholder" are additionally appended to `input_nodes`.
    // Throws when a node has an empty name.
    static node_map get_nodes(const tensorflow::GraphDef& graph,
                              std::vector<tensorflow::NodeDef>& input_nodes)
    {
        node_map by_name;
        for(auto&& node : graph.node())
        {
            // assume each node in graph has an associated name
            const auto node_name = get_name(node);
            if(node_name.empty())
                MIGRAPHX_THROW("tf node with no name found");

            by_name[node_name] = node;

            const bool is_placeholder = node.op() == "Placeholder";
            if(is_placeholder)
            {
                input_nodes.push_back(node);
            }
        }
        return by_name;
    }

    // Maps a tensorflow data type to the matching migraphx shape type.
    // Types without a mapping fall out of the switch and yield a value-initialized
    // shape::type_t. Every enumerator is listed explicitly and there is no default
    // label.
    static shape::type_t parse_type(const tensorflow::DataType t)
    {
        switch(t)
        {
        case tensorflow::DataType::DT_FLOAT: return shape::float_type;
        case tensorflow::DataType::DT_DOUBLE: return shape::double_type;
        case tensorflow::DataType::DT_INT32: return shape::int32_type;
        case tensorflow::DataType::DT_INT16: return shape::int16_type;
        case tensorflow::DataType::DT_INT8: return shape::int8_type;
        case tensorflow::DataType::DT_INT64: return shape::int64_type;
        case tensorflow::DataType::DT_UINT16: return shape::uint16_type;
        case tensorflow::DataType::DT_HALF: return shape::half_type;
        case tensorflow::DataType::DT_UINT32: return shape::uint32_type;
        case tensorflow::DataType::DT_UINT64: return shape::uint64_type;

        // no mapping for the types below
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_BOOL:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        // tf pb should not use these types
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
        }
        return shape::type_t{};
    }

Khalique's avatar
Khalique committed
1217
    // Converts a TensorFlow TensorProto into a literal.
    //
    // Two encodings are handled:
    //   * raw bytes: when `tensor_content` is non-empty, the literal is built
    //     directly from that byte buffer;
    //   * typed fields: otherwise the dtype-specific repeated field (float_val,
    //     int_val, int64_val, bool_val, half_val, double_val) is expanded with
    //     get_data_vals() to the element count implied by the tensor shape.
    //
    // Throws std::runtime_error for dtypes that are listed but not supported,
    // and MIGRAPHX_THROW("Invalid tensor type") for a dtype value that matches
    // no case label at all.
    static literal parse_tensor(const tensorflow::TensorProto& t)
    {
        std::vector<size_t> dims = parse_dims(t.tensor_shape());

        if(!t.tensor_content().empty()) // has raw data
        {
            // NOTE(review): nothing here checks that the byte buffer is as long
            // as the shape requires -- confirm literal's constructor cannot read
            // past the end of a truncated tensor_content.
            const std::string& s = t.tensor_content();
            switch(t.dtype())
            {
            case tensorflow::DataType::DT_FLOAT:
                return literal{{shape::float_type, dims}, s.data()};
            // NOTE(review): the typed-field path below stores DT_BOOL as
            // int32_type while this path uses int8_type -- confirm intended.
            case tensorflow::DataType::DT_BOOL:
            case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
            // NOTE(review): the typed-field path below maps DT_UINT16 to
            // uint16_type while this path uses int16_type -- confirm intended.
            case tensorflow::DataType::DT_UINT16:
            case tensorflow::DataType::DT_INT16:
                return literal{{shape::int16_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT32:
                return literal{{shape::int32_type, dims}, s.data()};
            case tensorflow::DataType::DT_INT64:
                return literal{{shape::int64_type, dims}, s.data()};
            case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
            case tensorflow::DataType::DT_DOUBLE:
                return literal{{shape::double_type, dims}, s.data()};
            // Every remaining enumerator is listed explicitly (no `default`).
            case tensorflow::DataType::DT_INVALID:
            case tensorflow::DataType::DT_UINT8:
            case tensorflow::DataType::DT_STRING:
            case tensorflow::DataType::DT_UINT32:
            case tensorflow::DataType::DT_UINT64:
            case tensorflow::DataType::DT_COMPLEX64:
            case tensorflow::DataType::DT_COMPLEX128:
            case tensorflow::DataType::DT_QINT8:
            case tensorflow::DataType::DT_QUINT8:
            case tensorflow::DataType::DT_QINT32:
            case tensorflow::DataType::DT_BFLOAT16:
            case tensorflow::DataType::DT_QINT16:
            case tensorflow::DataType::DT_QUINT16:
            case tensorflow::DataType::DT_RESOURCE:
            case tensorflow::DataType::DT_VARIANT:
            case tensorflow::DataType::DT_FLOAT_REF:
            case tensorflow::DataType::DT_DOUBLE_REF:
            case tensorflow::DataType::DT_INT32_REF:
            case tensorflow::DataType::DT_UINT8_REF:
            case tensorflow::DataType::DT_INT16_REF:
            case tensorflow::DataType::DT_INT8_REF:
            case tensorflow::DataType::DT_STRING_REF:
            case tensorflow::DataType::DT_COMPLEX64_REF:
            case tensorflow::DataType::DT_INT64_REF:
            case tensorflow::DataType::DT_BOOL_REF:
            case tensorflow::DataType::DT_QINT8_REF:
            case tensorflow::DataType::DT_QUINT8_REF:
            case tensorflow::DataType::DT_QINT32_REF:
            case tensorflow::DataType::DT_BFLOAT16_REF:
            case tensorflow::DataType::DT_QINT16_REF:
            case tensorflow::DataType::DT_QUINT16_REF:
            case tensorflow::DataType::DT_UINT16_REF:
            case tensorflow::DataType::DT_COMPLEX128_REF:
            case tensorflow::DataType::DT_HALF_REF:
            case tensorflow::DataType::DT_RESOURCE_REF:
            case tensorflow::DataType::DT_VARIANT_REF:
            case tensorflow::DataType::DT_UINT32_REF:
            case tensorflow::DataType::DT_UINT64_REF:
            case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
            case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
                throw std::runtime_error(
                    "parse_tensor: unsupported data type for raw tensor content");
            }
            MIGRAPHX_THROW("Invalid tensor type");
        }

        // Number of elements implied by the shape. The seed must be size_t:
        // std::accumulate carries its running value in the type of the seed, so
        // an int seed would narrow every partial product to int.
        const size_t shape_size =
            std::accumulate(dims.begin(), dims.end(), size_t{1}, std::multiplies<size_t>());

        switch(t.dtype())
        {
        case tensorflow::DataType::DT_FLOAT:
            return create_literal(
                shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
        case tensorflow::DataType::DT_INT8:
            return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_UINT16:
            return create_literal(shape::uint16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT16:
            return create_literal(shape::int16_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT32:
            return create_literal(shape::int32_type, dims, get_data_vals(t.int_val(), shape_size));
        case tensorflow::DataType::DT_INT64:
            return create_literal(
                shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
        case tensorflow::DataType::DT_BOOL:
            return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
        case tensorflow::DataType::DT_HALF:
        {
            // half_val carries each value in an int; narrow to 16 bits and
            // reinterpret those bits as a half.
            std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
            std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
            std::vector<half> data_half;
            data_half.reserve(data_uint16.size());
            std::transform(data_uint16.begin(),
                           data_uint16.end(),
                           std::back_inserter(data_half),
                           [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
            return create_literal(shape::half_type, dims, data_half);
        }
        // NOTE(review): unlike the other typed cases this one bypasses
        // create_literal (and its scalar handling) -- confirm intended.
        case tensorflow::DataType::DT_DOUBLE:
            return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
        // Every remaining enumerator is listed explicitly (no `default`).
        case tensorflow::DataType::DT_INVALID:
        case tensorflow::DataType::DT_UINT8:
        case tensorflow::DataType::DT_STRING:
        case tensorflow::DataType::DT_UINT32:
        case tensorflow::DataType::DT_UINT64:
        case tensorflow::DataType::DT_COMPLEX64:
        case tensorflow::DataType::DT_COMPLEX128:
        case tensorflow::DataType::DT_QINT8:
        case tensorflow::DataType::DT_QUINT8:
        case tensorflow::DataType::DT_QINT32:
        case tensorflow::DataType::DT_BFLOAT16:
        case tensorflow::DataType::DT_QINT16:
        case tensorflow::DataType::DT_QUINT16:
        case tensorflow::DataType::DT_RESOURCE:
        case tensorflow::DataType::DT_VARIANT:
        case tensorflow::DataType::DT_FLOAT_REF:
        case tensorflow::DataType::DT_DOUBLE_REF:
        case tensorflow::DataType::DT_INT32_REF:
        case tensorflow::DataType::DT_UINT8_REF:
        case tensorflow::DataType::DT_INT16_REF:
        case tensorflow::DataType::DT_INT8_REF:
        case tensorflow::DataType::DT_STRING_REF:
        case tensorflow::DataType::DT_COMPLEX64_REF:
        case tensorflow::DataType::DT_INT64_REF:
        case tensorflow::DataType::DT_BOOL_REF:
        case tensorflow::DataType::DT_QINT8_REF:
        case tensorflow::DataType::DT_QUINT8_REF:
        case tensorflow::DataType::DT_QINT32_REF:
        case tensorflow::DataType::DT_BFLOAT16_REF:
        case tensorflow::DataType::DT_QINT16_REF:
        case tensorflow::DataType::DT_QUINT16_REF:
        case tensorflow::DataType::DT_UINT16_REF:
        case tensorflow::DataType::DT_COMPLEX128_REF:
        case tensorflow::DataType::DT_HALF_REF:
        case tensorflow::DataType::DT_RESOURCE_REF:
        case tensorflow::DataType::DT_VARIANT_REF:
        case tensorflow::DataType::DT_UINT32_REF:
        case tensorflow::DataType::DT_UINT64_REF:
        case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
        case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
            throw std::runtime_error("parse_tensor: unsupported data type for tensor values");
        }
        MIGRAPHX_THROW("Invalid tensor type");
    }

1360
    template <class T>
Khalique's avatar
Khalique committed
1361
    static std::vector<T> get_data_vals(const google::protobuf::RepeatedField<T>& data,
1362
                                        const size_t& shape_size)
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
    {
        std::vector<T> data_vals(shape_size);
        // check if shape has enough data values given existing fields
        if(data.size() == 1)
        {
            std::fill(data_vals.begin(), data_vals.end(), data[0]);
        }
        else
            copy(data.begin(), data.end(), std::back_inserter(data_vals));
        return data_vals;
    }

Khalique's avatar
Khalique committed
1375
1376
1377
1378
    // Returns the extent of every dimension of `s`, in declaration order.
    //
    // NOTE(review): each extent is converted to size_t as-is; presumably the
    // proto stores it as a signed integer, in which case an unknown (negative)
    // extent would wrap to a very large value -- confirm callers never pass
    // such shapes here.
    static std::vector<size_t> parse_dims(const tensorflow::TensorShapeProto& s)
    {
        // Bind by reference: `auto` alone would copy the whole repeated field.
        const auto& input_dims = s.dim();

        std::vector<size_t> dims;
        dims.reserve(static_cast<size_t>(input_dims.size()));
        std::transform(input_dims.begin(),
                       input_dims.end(),
                       std::back_inserter(dims),
                       [](const tensorflow::TensorShapeProto_Dim& dim) {
                           return static_cast<size_t>(dim.size());
                       });
        return dims;
    }
1385
1386

    template <class T>
Khalique's avatar
Khalique committed
1387
    static literal
1388
    create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
1389
    {
Khalique's avatar
Khalique committed
1390
        // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
1391
        if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
1392
            return literal{{shape_type}, data};
1393
1394
        return literal{{shape_type, dims}, data};
    }
Khalique's avatar
Khalique committed
1395
1396
};

1397
// Reads a binary TensorFlow graph from the file `name` and returns the program
// built from it.
//
// `options.is_nhwc` and `options.batch_size` are copied onto the parser before
// parsing. After parsing, to_nchw() is applied to the last instruction of the
// main module.
//
// Throws if the file cannot be opened. Exceptions raised while parsing
// propagate to the caller; in builds without NDEBUG the program built so far
// is written to std::cerr first.
program parse_tf(const std::string& name, tf_options options)
{
    std::fstream input(name.c_str(), std::ios::in | std::ios::binary);
    // Fail early with the offending path instead of handing a dead stream to
    // the parser.
    if(not input.is_open())
    {
        MIGRAPHX_THROW("parse_tf: failed to open file: " + name);
    }

    tf_parser parser;
    parser.is_nhwc    = options.is_nhwc;
    parser.batch_size = options.batch_size;

#ifndef NDEBUG
    // Log the program when it can't be parsed
    try
    {
        parser.parse_from(input);
    }
    catch(...)
    {
        std::cerr << parser.prog << std::endl;
        throw;
    }
#else
    parser.parse_from(input);
#endif
    parser.to_nchw(std::prev(parser.mm->end()));
    // `prog` is a member of the local parser, so it is not eligible for
    // implicit move on return; move it out explicitly to avoid a copy.
    return std::move(parser.prog);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx