"docs/en_US/LocalMode.md" did not exist on "45c6508eec59fbc40255d9703b6c606c76b1d842"
parse_conv.cpp 3.47 KB
Newer Older
kahmed10's avatar
kahmed10 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <migraphx/tf/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pad_calc.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace tf {

/// Converts a TensorFlow `Conv2D` node into a MIGraphX `convolution` instruction.
///
/// Reads the node's `strides`, `dilations` and `padding` ("SAME", "VALID", "EXPLICIT")
/// attributes into an `op::convolution`, converts the filter with `parser.to_kcxy`, and
/// adds the convolution over (input, converted filter).
struct parse_conv : op_parser<parse_conv>
{
    // NOTE(review): presumably tells the op_parser framework to apply its layout
    // transposition for this op — confirm against op_parser.
    bool transpose() const { return true; }

    /// TensorFlow op names handled by this parser.
    std::vector<op_desc> operators() const { return {{"Conv2D"}}; }

    /// Builds the convolution instruction for one `Conv2D` node.
    ///
    /// @param parser  TF parser; used for `reorder_data` and `to_kcxy`.
    /// @param info    Node info; its attributes are read and the instruction is added through it.
    /// @param args    Node inputs: args[0] is the input tensor, args[1] the filter.
    /// @return The added convolution instruction.
    /// @throws via MIGRAPHX_THROW when fewer than two inputs are supplied, when
    ///         `strides`/`dilations`/`explicit_paddings` do not hold 4 values, when
    ///         `explicit_paddings` is missing for EXPLICIT padding, or when the explicit
    ///         padding is asymmetric.
    instruction_ref parse(const op_desc& /*opd*/,
                          const tf_parser& parser,
                          tf_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        // Guard the args[1] access below; indexing past the end would be undefined behavior.
        if(args.size() < 2)
        {
            MIGRAPHX_THROW("Conv2D expects an input and a filter argument");
        }

        op::convolution op;
        if(contains(info.attributes, "strides"))
        {
            const auto stride =
                read_four_values(parser, info, "strides", "strides should have 4 values");
            // Entries 2 and 3 (after reorder_data) are taken as the two spatial dimensions.
            op.stride[0] = stride[2];
            op.stride[1] = stride[3];
        }
        if(contains(info.attributes, "dilations"))
        {
            const auto dilation =
                read_four_values(parser, info, "dilations", "dilation should have 4 values");
            op.dilation[0] = dilation[2];
            op.dilation[1] = dilation[3];
        }

        auto weights = parser.to_kcxy(args[1]);
        auto l0      = args[0];
        if(contains(info.attributes, "padding"))
        {
            const std::string& pad_mode = info.attributes.at("padding").s();
            if(pad_mode.find("SAME") != std::string::npos)
            {
                op.padding_mode              = op::padding_mode_t::same;
                std::vector<int> weight_dims = weights->get_shape().lens();
                int weight_h                 = weight_dims[2];
                int weight_w                 = weight_dims[3];

                // Materialize the SAME padding from the input's dims 2/3, the filter's
                // dims 2/3, and the current stride/dilation values.
                auto input_dims = l0->get_shape().lens();
                std::vector<int64_t> pads(input_dims.size());
                calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
                calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);

                op.padding = std::vector<int>(pads.begin(), pads.end());
            }
            else if(pad_mode.find("VALID") != std::string::npos)
            {
                op.padding_mode = op::padding_mode_t::valid;
            }
            else if(pad_mode.find("EXPLICIT") != std::string::npos)
            {
                // Report a missing attribute explicitly instead of letting map::at throw
                // a bare std::out_of_range.
                if(!contains(info.attributes, "explicit_paddings"))
                {
                    MIGRAPHX_THROW("EXPLICIT padding requires the explicit_paddings attribute");
                }
                std::vector<int> padding;
                copy(info.attributes.at("explicit_paddings").list().i(),
                     std::back_inserter(padding));
                // NOTE(review): TensorFlow documents explicit_paddings as two values per
                // dimension (eight for a 4-D Conv2D input); this parser accepts exactly four —
                // confirm what the graph exporter feeding this parser emits.
                if(padding.size() != 4)
                {
                    MIGRAPHX_THROW("padding should have 4 values");
                }
                // Only symmetric padding can be expressed: entry 0 must equal entry 2 and
                // entry 1 must equal entry 3.
                if(padding[0] != padding[2] || padding[1] != padding[3])
                {
                    MIGRAPHX_THROW("migraphx does not support asymetric padding");
                }
                op.padding[0] = padding[0];
                op.padding[1] = padding[1];
            }
        }
        return info.add_instruction(op, {l0, weights});
    }

    private:
    /// Reads the integer-list attribute `name` from `info`, requires exactly four entries
    /// (throwing `error_msg` otherwise), then applies `parser.reorder_data` and returns it.
    /// The size is validated before reordering so a malformed list is reported with the
    /// intended message.
    static std::vector<int> read_four_values(const tf_parser& parser,
                                             const tf_parser::node_info& info,
                                             const std::string& name,
                                             const std::string& error_msg)
    {
        std::vector<int> values;
        copy(info.attributes.at(name).list().i(), std::back_inserter(values));
        if(values.size() != 4)
        {
            MIGRAPHX_THROW(error_msg);
        }
        parser.reorder_data(values);
        return values;
    }
};

} // namespace tf
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx