parse_quantizelinear.cpp 2.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>

#include <string>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/// Parser for the ONNX `QuantizeLinear` operator.
///
/// Lowers `QuantizeLinear(x, y_scale[, y_zero_point])` to the MIGraphX
/// `quantizelinear` op. Before the op is added, `y_scale` and the optional
/// `y_zero_point` are expanded to the shape of `x`:
///   - a single-element tensor is expanded with `multibroadcast`;
///   - any other tensor is expanded with `broadcast` along the `axis`
///     attribute (default 1; normalized/validated by `tune_axis`).
struct parse_quantizelinear : op_parser<parse_quantizelinear>
{
    std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }

    /// Builds the quantizelinear instruction for one ONNX node.
    ///
    /// @param opd   descriptor of the matched operator; its op_name is forwarded
    ///              to tune_axis for error reporting.
    /// @param info  node attributes and the instruction-insertion context.
    /// @param args  node inputs: x, y_scale and optionally y_zero_point.
    /// @return the added `quantizelinear` instruction.
    /// @throws if the node does not have 2 or 3 inputs, or (via tune_axis) if
    ///         `axis` is out of range for a per-axis (non-scalar) scale/zero point.
    instruction_ref parse(const op_desc& opd,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          const std::vector<instruction_ref>& args) const
    {
        // Guard before indexing args[1]; also reject extra inputs instead of
        // silently ignoring the zero point.
        if(args.size() < 2 or args.size() > 3)
        {
            MIGRAPHX_THROW("QuantizeLinear: must have either 2 or 3 inputs, " +
                           std::to_string(args.size()) + " input(s) provided");
        }

        int axis = 1;
        if(contains(info.attributes, "axis"))
            axis = info.attributes.at("axis").i();

        const auto input_lens = args[0]->get_shape().lens();
        const auto n_dim      = input_lens.size();

        // Expands a scale / zero-point tensor to the shape of the input.
        // The axis is only tuned (and thus validated) for non-scalar tensors,
        // which are the only ones that use it.
        auto broadcast_to_input = [&](instruction_ref param) {
            if(param->get_shape().elements() == 1)
            {
                return info.add_instruction(
                    make_op("multibroadcast", {{"out_lens", input_lens}}), param);
            }
            auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
            return info.add_instruction(
                make_op("broadcast", {{"axis", tuned_axis}, {"out_lens", input_lens}}), param);
        };

        // Scale is broadcast first, then the zero point (same insertion order as before).
        auto y_scale = broadcast_to_input(args[1]);
        if(args.size() == 3)
        {
            auto y_zero_point = broadcast_to_input(args[2]);
            return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
        }

        return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx