parse_if.cpp 7.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Shucai Xiao's avatar
Shucai Xiao committed
24
#include <migraphx/instruction_ref.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
25
26
27
28
29
30
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
31
32
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
Shucai Xiao's avatar
Shucai Xiao committed
33
34
35
36
37

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

Ted Themistokleous's avatar
Ted Themistokleous committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// Returns true when the shorter of the two dimension lists is an element-wise
// prefix of the longer one (lists of equal length must match entirely).
// An empty list is a prefix of anything, so it always compares equal.
inline bool all_but_last_dims_equal(const std::vector<size_t>& lens_a,
                                    const std::vector<size_t>& lens_b)
{
    // Ties go to lens_a, so equal-length inputs compare lens_a against lens_b.
    const bool a_is_shorter = lens_a.size() <= lens_b.size();
    const auto& shorter     = a_is_shorter ? lens_a : lens_b;
    const auto& longer      = a_is_shorter ? lens_b : lens_a;

    // Only the first shorter.size() entries of `longer` are inspected.
    return std::equal(shorter.begin(), shorter.end(), longer.begin());
}

// Inserts an "unsqueeze" instruction immediately before the last instruction of
// `mdl`, fed by the `index`-th input of that last instruction, and then asks the
// module to replace that original input with the new unsqueeze instruction.
//
// mdl       - module to modify; its last instruction must have more than
//             `index` inputs (`.at()` throws otherwise).
// index     - position of the input of the last instruction to unsqueeze.
// out_shape - target dimensions; the new axis is placed at
//             out_shape.size() - 1, so this must not be empty.
void unsqueeze_last_op(module_ref mdl, int index, const std::vector<size_t>& out_shape)
{
    // The insertion happens in front of the last instruction, so this iterator
    // keeps designating the same instruction afterwards.
    const auto last_ins   = std::prev(mdl->end());
    const auto target_ins = last_ins->inputs().at(index);

    auto unsqueezed = mdl->insert_instruction(
        last_ins, make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}), target_ins);

    mdl->replace_instruction(target_ins, unsqueezed);
}

Shucai Xiao's avatar
Shucai Xiao committed
60
61
62
63
64
65
66
67
68
struct parse_if : op_parser<parse_if>
{
    std::vector<op_desc> operators() const { return {{"If"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
69
70
        const auto& then_graph = info.attributes.at("then_branch").g();
        const auto& else_graph = info.attributes.at("else_branch").g();
Shucai Xiao's avatar
Shucai Xiao committed
71

Shucai Xiao's avatar
Shucai Xiao committed
72
        if(args.front()->get_shape().elements() != 1)
Shucai Xiao's avatar
Shucai Xiao committed
73
        {
74
75
            MIGRAPHX_THROW("PARSE_IF: " + info.name +
                           " condition input can have only one element!");
Shucai Xiao's avatar
Shucai Xiao committed
76
77
        }

Shucai Xiao's avatar
Shucai Xiao committed
78
79
        std::string then_name = info.name + "_if";
        module_ref then_mdl   = parser.prog.create_module(then_name);
Shucai Xiao's avatar
Shucai Xiao committed
80

Shucai Xiao's avatar
Shucai Xiao committed
81
82
        std::string else_name = info.name + "_else";
        module_ref else_mdl   = parser.prog.create_module(else_name);
Shucai Xiao's avatar
Shucai Xiao committed
83

Shucai Xiao's avatar
Shucai Xiao committed
84
85
        // parse the then sub_graph
        parser.parse_graph(then_mdl, then_graph);
Shucai Xiao's avatar
Shucai Xiao committed
86

Shucai Xiao's avatar
Shucai Xiao committed
87
88
        // parse_the else sub_graph
        parser.parse_graph(else_mdl, else_graph);
Shucai Xiao's avatar
Shucai Xiao committed
89

Shucai Xiao's avatar
Shucai Xiao committed
90
91
        auto then_out_shapes = then_mdl->get_output_shapes();
        auto else_out_shapes = else_mdl->get_output_shapes();
92

93
94
        auto throw_shapes = [&]() {
            MIGRAPHX_THROW("PARSE_IF: " + info.name +
95
                           " then and else sub_graphs must have compatible shapes ");
96
97
98
99
100
101
        };

        if(then_out_shapes.size() != else_out_shapes.size())
        {
            throw_shapes();
        }
102

Ted Themistokleous's avatar
Ted Themistokleous committed
103
104
        // Add checks for each output shape
        for(int i = 0; i < then_out_shapes.size(); i++)
Shucai Xiao's avatar
Shucai Xiao committed
105
        {
Ted Themistokleous's avatar
Ted Themistokleous committed
106
107
108
            const auto& then_out_shape = then_out_shapes.at(i);
            const auto& else_out_shape = else_out_shapes.at(i);

Ted Themistokleous's avatar
Ted Themistokleous committed
109
            // Must have the same type for both if/else blocks by onnx spec
Ted Themistokleous's avatar
Ted Themistokleous committed
110
            if(then_out_shape.type() != else_out_shape.type())
111
            {
Ted Themistokleous's avatar
Ted Themistokleous committed
112
113
                MIGRAPHX_THROW("PARSE_IF: " + info.name +
                               " then and else sub_grahps must have same output type! " +
Ted Themistokleous's avatar
Ted Themistokleous committed
114
115
                               then_out_shape.type_string() + " vs " +
                               else_out_shape.type_string());
116
117
            }

Ted Themistokleous's avatar
Ted Themistokleous committed
118
            if(not then_out_shape.dynamic() and not else_out_shape.dynamic())
119
            {
Ted Themistokleous's avatar
Ted Themistokleous committed
120
121
                auto then_lens = then_out_shape.lens();
                auto else_lens = else_out_shape.lens();
122

Ted Themistokleous's avatar
Ted Themistokleous committed
123
124
                // Throw error if both branches have zero output shapes. Not possible for static
                // inputs
Ted Themistokleous's avatar
Ted Themistokleous committed
125
                if(then_lens.empty() and else_lens.empty())
Ted Themistokleous's avatar
Ted Themistokleous committed
126
127
128
                {
                    throw_shapes();
                }
129

Ted Themistokleous's avatar
Ted Themistokleous committed
130
                auto handle_empty_branch = [](module_ref& mdl, const shape& out_shape) {
131
132
133
                    shape gen_shape(out_shape.type(), out_shape.lens(), out_shape.strides());
                    auto literal_ins = mdl->add_literal(gen_shape, gen_shape.lens());
                    mdl->replace_return({literal_ins});
Ted Themistokleous's avatar
Ted Themistokleous committed
134
135
136
137
138
139
140
                    return out_shape.lens();
                };

                // Handle one empty branch by setting output identical to the other
                // need to update the then_shape before we do further checks
                if(then_lens.empty())
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
141
                    then_lens = handle_empty_branch(then_mdl, else_out_shape);
Ted Themistokleous's avatar
Ted Themistokleous committed
142
143
144
                }
                else if(else_lens.empty())
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
145
                    else_lens = handle_empty_branch(else_mdl, then_out_shape);
Ted Themistokleous's avatar
Ted Themistokleous committed
146
147
148
                }

                // check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
Ted Themistokleous's avatar
Ted Themistokleous committed
149
                int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));
Ted Themistokleous's avatar
Ted Themistokleous committed
150

Ted Themistokleous's avatar
Ted Themistokleous committed
151
                if(rank_delta == 1)
152
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
153
154
155
156
157
                    // make sure dims are equivalent in static shapes
                    if(not all_but_last_dims_equal(then_lens, else_lens))
                    {
                        throw_shapes();
                    }
158

Ted Themistokleous's avatar
Ted Themistokleous committed
159
160
                    auto last_then = then_lens.back();
                    auto last_else = else_lens.back();
161

Ted Themistokleous's avatar
Ted Themistokleous committed
162
163
164
                    // Find which dim to unsqueeze
                    if((then_lens.size() < else_lens.size()) && (last_else == 1))
                    {
Ted Themistokleous's avatar
Ted Themistokleous committed
165
                        unsqueeze_last_op(then_mdl, i, else_lens);
Ted Themistokleous's avatar
Ted Themistokleous committed
166
167
168
                    }
                    else if((then_lens.size() > else_lens.size()) && (last_then == 1))
                    {
Ted Themistokleous's avatar
Ted Themistokleous committed
169
                        unsqueeze_last_op(else_mdl, i, then_lens);
Ted Themistokleous's avatar
Ted Themistokleous committed
170
                    }
171
                }
Ted Themistokleous's avatar
Ted Themistokleous committed
172
                else if(rank_delta > 1)
173
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
174
                    throw_shapes();
175
176
                }
            }
Shucai Xiao's avatar
Shucai Xiao committed
177
        }
Shucai Xiao's avatar
Shucai Xiao committed
178

Shucai Xiao's avatar
Shucai Xiao committed
179
180
181
        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
        auto out_s  = if_ret->get_shape();
        assert(out_s.type() == shape::tuple_type);
Shucai Xiao's avatar
Shucai Xiao committed
182

Shucai Xiao's avatar
Shucai Xiao committed
183
184
185
186
187
188
        const auto& vec_shapes = out_s.sub_shapes();
        std::vector<instruction_ref> out_inss;
        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
        {
            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
            out_inss.push_back(ret);
Shucai Xiao's avatar
Shucai Xiao committed
189
        }
Shucai Xiao's avatar
Shucai Xiao committed
190
191

        return out_inss;
Shucai Xiao's avatar
Shucai Xiao committed
192
193
194
195
196
197
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx