// parse_depthwiseconv.cpp: TF DepthwiseConv2dNative -> MIGraphX grouped convolution parser.
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {

struct parse_depthwiseconv : op_parser<parse_depthwiseconv>
{
    // TF tensors arrive NHWC; returning true asks the parser to transpose to NCHW.
    bool transpose() const { return true; }
    std::vector<op_desc> operators() const { return {{"DepthwiseConv2dNative"}}; }

    /// Parse a TF DepthwiseConv2dNative node into a grouped migraphx convolution.
    ///
    /// A depthwise convolution over C input channels with channel multiplier M is
    /// expressed here as an ordinary convolution with group == C whose weights
    /// are reshaped from (M, C, h, w) to (C*M, 1, h, w).
    ///
    /// @param parser  owning tf_parser, used for layout reordering helpers
    /// @param info    node attributes (strides, dilations, padding) and graph access
    /// @param args    args[0] = input (NCHW after transpose), args[1] = weights
    /// @return        instruction_ref of the inserted convolution
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& parser,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        op::convolution op;
        // Input is NCHW after the parser's transpose, so channels sit at index 1.
        // NOTE: lens() yields std::size_t dims — keep the unsigned type, do not
        // narrow into int.
        std::size_t num_channels = args[0]->get_shape().lens()[1];
        op.group                 = num_channels;

        if(contains(info.attributes, "strides"))
        {
            std::vector<int> stride;
            copy(info.attributes.at("strides").list().i(), std::back_inserter(stride));
            parser.reorder_data(stride);
            if(stride.size() != 4)
            {
                MIGRAPHX_THROW("strides should have 4 values");
            }
            // After reorder_data the vector is NCHW ordered; spatial strides are at 2,3.
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }

        // Reorder weights from TF's (h, w, in, out) to (out, in, h, w).
        auto weights = parser.to_kcxy(args[1]);
        if(contains(info.attributes, "dilations"))
        {
            std::vector<int> dilation;
            copy(info.attributes.at("dilations").list().i(), std::back_inserter(dilation));
            parser.reorder_data(dilation);
            if(dilation.size() != 4)
            {
                MIGRAPHX_THROW("dilation should have 4 values");
            }
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto l0 = args[0];
        if(contains(info.attributes, "padding"))
        {
            const std::string& pad_mode = info.attributes.at("padding").s();

            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::same;
                // Fix: lens() returns std::vector<std::size_t>; the previous
                // std::vector<int> initialization had no valid conversion.
                auto weight_dims     = weights->get_shape().lens();
                std::size_t weight_h = weight_dims[2];
                std::size_t weight_w = weight_dims[3];

                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                // TF SAME padding may be asymmetric; the convolution op only takes
                // symmetric padding, so insert an explicit pad instruction then.
                if(pads[0] != pads[2] || pads[1] != pads[3])
                {
                    std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
                    l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
                }
                else
                {
                    op.padding[0] = pads[0];
                    op.padding[1] = pads[1];
                }
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
        }

        std::vector<int64_t> new_weights_shape;
        copy(weights->get_shape().lens(), std::back_inserter(new_weights_shape));

        // weight format is (out_channels, in_channels, h, w), but in depthwise_conv,
        // out_channels is equal to the multiplier. Adjust by inserting a reshape and
        // setting in_channels to 1
        int64_t multiplier   = new_weights_shape[0];
        int64_t out_channels = num_channels * multiplier;
        new_weights_shape[0] = out_channels;
        new_weights_shape[1] = 1;
        // Make sure weights are contiguous before doing reshape
        auto new_weights = info.add_instruction(make_op("reshape", {{"dims", new_weights_shape}}),
                                                info.make_contiguous(weights));

        return info.add_instruction(op, {l0, new_weights});
    }
};

} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx