parse_convolution.cpp 3.34 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/onnx/conv.hpp>
#include <migraphx/onnx/padding.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/// ONNX node parser for convolution nodes.
///
/// Registers the ONNX operators "Conv" and "ConvInteger" and maps them to the
/// "convolution" and "quant_convolution" operators respectively.
struct parse_convolution : op_parser<parse_convolution>
{
    /// Pairs of (ONNX operator name, target operator name) handled by this parser.
    std::vector<op_desc> operators() const
    {
        return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
    }

    /// Builds the convolution instruction for a single ONNX node.
    ///
    /// @param opd    Descriptor whose `op_name` selects the operator to create.
    /// @param parser Parser used to decode the `group` attribute value.
    /// @param info   Node information; its `attributes` are read for `strides`,
    ///               `dilations`, `pads`, `auto_pad` and `group`.
    /// @param args   Node inputs: `args[0]` is the data input and `args[1]` the
    ///               weights. The whole vector is forwarded to `info.add_bias`.
    /// @return The instruction returned by `info.add_bias` for the new
    ///         convolution instruction.
    ///
    /// NOTE(review): `args` is indexed at 0 and 1 without a size check; this
    /// assumes the caller always supplies at least two inputs -- confirm.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto op       = make_op(opd.op_name);
        auto values   = op.to_value();
        auto l0       = args[0];
        auto weights  = args[1];
        auto l0_shape = l0->get_shape();
        auto in_lens  = l0_shape.max_lens();

        // Every dimension after the first two is treated as a kernel dimension.
        // NOTE(review): assert is compiled out under NDEBUG; with two or fewer
        // dimensions the subtraction below would presumably wrap around
        // (unsigned size) -- consider a hard error here.
        assert(in_lens.size() > 2);
        auto kdims = in_lens.size() - 2;

        // ensure pads available only when auto_pad is "NOT_SET"
        check_padding_mode(info, "CONV");

        if(contains(info.attributes, "strides"))
        {
            values["stride"].clear();
            copy(info.attributes["strides"].ints(), std::back_inserter(values["stride"]));
            check_attr_sizes(kdims, values["stride"].size(), "PARSE_CONV: inconsistent strides");
        }
        if(contains(info.attributes, "dilations"))
        {
            values["dilation"].clear();
            copy(info.attributes["dilations"].ints(), std::back_inserter(values["dilation"]));
            check_attr_sizes(
                kdims, values["dilation"].size(), "PARSE_CONV: inconsistent dilations");
        }

        // Explicit pads are collected here first and written to
        // values["padding"] only after the auto_pad handling below.
        std::vector<int64_t> padding;
        if(contains(info.attributes, "pads"))
        {
            values["padding"].clear();
            copy(info.attributes["pads"].ints(), std::back_inserter(padding));
            // NOTE(review): the integer division lets an odd number of pad
            // values through this check -- confirm that is acceptable.
            check_attr_sizes(kdims, padding.size() / 2, "PARSE_CONV: inconsistent paddings");
        }

        if(contains(info.attributes, "auto_pad"))
        {
            // Kernel lengths are the weight dimensions after the first two.
            auto weight_lens = weights->get_shape().max_lens();
            std::vector<std::size_t> k_lens(weight_lens.begin() + 2, weight_lens.end());
            cal_auto_padding_size(info,
                                  values,
                                  k_lens,
                                  values["dilation"].to_vector<std::size_t>(),
                                  in_lens,
                                  padding);
            auto auto_pad = info.attributes["auto_pad"].s();
            // Any auto_pad value containing "SAME" selects the "same" padding mode.
            if(auto_pad.find("SAME") != std::string::npos)
            {
                values["padding_mode"] = to_value(op::padding_mode_t::same);
            }
        }
        values["padding"] = std::vector<std::size_t>(padding.begin(), padding.end());

        if(contains(info.attributes, "group"))
        {
            values["group"] = parser.parse_value(info.attributes.at("group")).at<int>();
        }

        recalc_conv_attributes(values, kdims);

        op.from_value(values);
        auto l1 = info.add_instruction(op, l0, weights);
        return info.add_bias(args, l1, 1);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx