parse_clip.cpp 2.44 KB
Newer Older
Paul Fultz II's avatar
Paul Fultz II committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_clip : op_parser<parse_clip>
{
    std::vector<op_desc> operators() const { return {{"Clip"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          std::vector<instruction_ref> args) const
    {
        auto input_lens = args[0]->get_shape().lens();
        instruction_ref min_arg;
        instruction_ref max_arg;
        bool min_used = false;
        bool max_used = false;

        if(args.size() == 3 and args[2]->name() != "undefined")
        {
            max_arg  = args[2];
            max_used = true;
        }

        if(args.size() >= 2 and args[1]->name() != "undefined")
        {
            min_arg  = args[1];
            min_used = true;
        }
        // if using previous opset for attributes
        else if(contains(info.attributes, "min") and contains(info.attributes, "max"))
        {

            float min_val = parser.parse_value(info.attributes.at("min")).at<float>();
            float max_val = parser.parse_value(info.attributes.at("max")).at<float>();
            min_arg       = info.add_literal(min_val);
            max_arg       = info.add_literal(max_val);
            min_used      = true;
            max_used      = true;
        }

        if(min_used)
        {
50
            min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
Paul Fultz II's avatar
Paul Fultz II committed
51
52
53
54
55
                                           min_arg);
        }

        if(max_used)
        {
56
            max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
Paul Fultz II's avatar
Paul Fultz II committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
                                           max_arg);
        }

        if(min_used and max_used)
        {
            return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg);
        }
        else if(max_used)
        {
            return info.add_instruction(make_op("min"), args[0], max_arg);
        }
        else if(min_used)
        {
            return info.add_instruction(make_op("max"), args[0], min_arg);
        }
        else
        {
            return info.add_instruction(make_op("identity"), args[0]);
        }
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx