// parse_if.cpp — ONNX "If" operator parser (last committed by Shucai Xiao)
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

/// Parser for the ONNX "If" operator.
///
/// The node carries two graph attributes, "then_branch" and "else_branch", and
/// takes the branch condition as its first input (args.front()).
///
/// - If the condition can be evaluated at parse time (args.front()->eval()
///   yields a non-empty argument), only the selected branch is parsed, directly
///   into the current module, and that branch's outputs replace the If node.
/// - Otherwise both branches are parsed into two newly created sub-modules and a
///   single "if" instruction referencing them is added to the current module.
struct parse_if : op_parser<parse_if>
{
    std::vector<op_desc> operators() const { return {{"If"}}; }

    /// @param parser  ONNX parser; used to create sub-modules and parse sub-graphs
    /// @param info    node information (attributes, node name, target module)
    /// @param args    node inputs; args.front() is the condition
    /// @return        instructions standing for the outputs of the If node
    /// @throws        via MIGRAPHX_THROW if the condition input is missing or has
    ///                more than one element, if the two branches produce different
    ///                output shapes, or if an inlined branch does not end with a
    ///                "@return" instruction
    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
        const auto& then_graph = info.attributes.at("then_branch").g();
        const auto& else_graph = info.attributes.at("else_branch").g();

        // Guard before dereferencing args.front() below.
        if(args.empty())
        {
            MIGRAPHX_THROW("PARSE_IF: condition input is missing!");
        }

        if(args.front()->get_shape().elements() != 1)
        {
            MIGRAPHX_THROW("PARSE_IF: condition input can have only one element!");
        }

        // An empty argument means the condition could not be evaluated at parse
        // time, so the branch choice has to be deferred to run time.
        migraphx::argument cond_arg = args.front()->eval();
        if(not cond_arg.empty())
        {
            // Constant condition: inline only the selected branch into the current
            // module instead of emitting an "if" instruction.
            auto* mod = info.mod;
            parser.parse_graph(mod, cond_arg.at<bool>() ? then_graph : else_graph);

            // The last instruction is expected to be the "@return" that parse_graph
            // appends; its inputs are the outputs of the If node. Check this
            // explicitly (not just with assert) so release builds never remove an
            // unrelated instruction.
            if(mod->begin() == mod->end() or std::prev(mod->end())->name() != "@return")
            {
                MIGRAPHX_THROW("PARSE_IF: inlined branch does not end with a return instruction!");
            }
            instruction_ref ret_ins = std::prev(mod->end());

            // Copy (not reference) the inputs: ret_ins is removed right after.
            auto outputs = ret_ins->inputs();
            mod->remove_instruction(ret_ins);

            return outputs;
        }

        // Run-time condition: parse each branch into its own sub-module and emit a
        // single "if" instruction that selects between them.
        module_ref then_mdl = parser.prog.create_module(info.name + "_if");
        module_ref else_mdl = parser.prog.create_module(info.name + "_else");

        parser.parse_graph(then_mdl, then_graph);
        parser.parse_graph(else_mdl, else_graph);

        // Both branches must yield identical output shapes so that the result of
        // the "if" instruction has a single well-defined shape.
        const auto then_out_shapes = then_mdl->get_output_shapes();
        const auto else_out_shapes = else_mdl->get_output_shapes();
        if(not std::equal(then_out_shapes.begin(),
                          then_out_shapes.end(),
                          else_out_shapes.begin(),
                          else_out_shapes.end()))
        {
            MIGRAPHX_THROW("PARSE_IF: then and else sub_graphs must have same output shapes!");
        }

        return {info.add_instruction(make_op("if"), args, {then_mdl, else_mdl})};
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx