parse_if.cpp 8.03 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
Shucai Xiao's avatar
Shucai Xiao committed
24
#include <migraphx/instruction_ref.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
25
26
27
28
29
30
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
31
32
#include <migraphx/reduce_dims.hpp>
#include <algorithm>
Shucai Xiao's avatar
Shucai Xiao committed
33
34
35
36
37

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

Ted Themistokleous's avatar
Ted Themistokleous committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/// Returns true when the shorter of the two dimension lists is a prefix of the
/// longer one. For lists of equal length this means every dimension matches;
/// for ranks n and n+1 it means the first n dimensions agree (the extra trailing
/// dimension of the longer list is not inspected).
inline bool all_but_last_dims_equal(const std::vector<size_t>& lens_a,
                                    const std::vector<size_t>& lens_b)
{
    const bool a_is_shorter = lens_a.size() <= lens_b.size();
    const auto& shorter     = a_is_shorter ? lens_a : lens_b;
    const auto& longer      = a_is_shorter ? lens_b : lens_a;
    // `longer` has at least as many elements as `shorter`, so the
    // three-iterator form of mismatch never reads past its end.
    return std::mismatch(shorter.begin(), shorter.end(), longer.begin()).first == shorter.end();
}

// Unsqueeze output `index` of `mdl` on its last axis so that a rank n-1 branch
// output matches the rank-n shape `out_shape` of the other branch, i.e.
// (x1,..,xn-1) -> (x1,..,xn-1,1).
//
// The unsqueeze is inserted just before the module's last instruction. Only the
// argument at `index` of that last instruction is rewired (via replace_return).
// Other instructions that consume the original output keep using it unchanged.
// (Replacing every use would point earlier consumers at an instruction that is
// defined after them.)
// NOTE(review): assumes the last instruction of `mdl` is its @return, as added
// by parse_graph — TODO confirm.
//
// Throws (MIGRAPHX_THROW) if `out_shape` is empty, since there is no last axis.
static void unsqueeze_last_op(module_ref mdl, int index, const std::vector<size_t>& out_shape)
{
    if(out_shape.empty())
    {
        MIGRAPHX_THROW("PARSE_IF: cannot unsqueeze a branch output to an empty shape");
    }

    auto ret_ins  = std::prev(mdl->end());
    auto ret_args = ret_ins->inputs();
    auto unsqueeze_ins =
        mdl->insert_instruction(ret_ins,
                                make_op("unsqueeze", {{"axes", {out_shape.size() - 1}}}),
                                ret_args.at(index));
    ret_args.at(index) = unsqueeze_ins;
    mdl->replace_return(ret_args);
}

Shucai Xiao's avatar
Shucai Xiao committed
60
61
62
63
64
65
66
67
68
struct parse_if : op_parser<parse_if>
{
    std::vector<op_desc> operators() const { return {{"If"}}; }

    std::vector<instruction_ref> parse(const op_desc& /*opd*/,
                                       onnx_parser& parser,
                                       const onnx_parser::node_info& info,
                                       std::vector<instruction_ref> args) const
    {
Shucai Xiao's avatar
Shucai Xiao committed
69
70
        const auto& then_graph = info.attributes.at("then_branch").g();
        const auto& else_graph = info.attributes.at("else_branch").g();
Shucai Xiao's avatar
Shucai Xiao committed
71

Shucai Xiao's avatar
Shucai Xiao committed
72
        if(args.front()->get_shape().elements() != 1)
Shucai Xiao's avatar
Shucai Xiao committed
73
        {
74
75
            MIGRAPHX_THROW("PARSE_IF: " + info.name +
                           " condition input can have only one element!");
Shucai Xiao's avatar
Shucai Xiao committed
76
77
        }

Shucai Xiao's avatar
Shucai Xiao committed
78
79
        std::string then_name = info.name + "_if";
        module_ref then_mdl   = parser.prog.create_module(then_name);
Shucai Xiao's avatar
Shucai Xiao committed
80

Shucai Xiao's avatar
Shucai Xiao committed
81
82
        std::string else_name = info.name + "_else";
        module_ref else_mdl   = parser.prog.create_module(else_name);
Shucai Xiao's avatar
Shucai Xiao committed
83

Shucai Xiao's avatar
Shucai Xiao committed
84
85
        // parse the then sub_graph
        parser.parse_graph(then_mdl, then_graph);
Shucai Xiao's avatar
Shucai Xiao committed
86

Shucai Xiao's avatar
Shucai Xiao committed
87
88
        // parse_the else sub_graph
        parser.parse_graph(else_mdl, else_graph);
Shucai Xiao's avatar
Shucai Xiao committed
89

Shucai Xiao's avatar
Shucai Xiao committed
90
91
        auto then_out_shapes = then_mdl->get_output_shapes();
        auto else_out_shapes = else_mdl->get_output_shapes();
92

93
94
        auto throw_shapes = [&]() {
            MIGRAPHX_THROW("PARSE_IF: " + info.name +
95
                           " then and else sub_graphs must have compatible shapes ");
96
97
98
99
100
101
        };

        if(then_out_shapes.size() != else_out_shapes.size())
        {
            throw_shapes();
        }
102

Ted Themistokleous's avatar
Ted Themistokleous committed
103
104
        // Add checks for each output shape
        for(int i = 0; i < then_out_shapes.size(); i++)
Shucai Xiao's avatar
Shucai Xiao committed
105
        {
Ted Themistokleous's avatar
Ted Themistokleous committed
106
107
108
            const auto& then_out_shape = then_out_shapes.at(i);
            const auto& else_out_shape = else_out_shapes.at(i);

Ted Themistokleous's avatar
Ted Themistokleous committed
109
            // Must have the same type for both if/else blocks by onnx spec
Ted Themistokleous's avatar
Ted Themistokleous committed
110
            if(then_out_shape.type() != else_out_shape.type())
111
            {
Ted Themistokleous's avatar
Ted Themistokleous committed
112
113
                MIGRAPHX_THROW("PARSE_IF: " + info.name +
                               " then and else sub_grahps must have same output type! " +
Ted Themistokleous's avatar
Ted Themistokleous committed
114
115
                               then_out_shape.type_string() + " vs " +
                               else_out_shape.type_string());
116
117
            }

Ted Themistokleous's avatar
Ted Themistokleous committed
118
            if(then_out_shape.dynamic() or else_out_shape.dynamic())
119
            {
Ted Themistokleous's avatar
Ted Themistokleous committed
120
121
                continue;
            }
122

Ted Themistokleous's avatar
Ted Themistokleous committed
123
124
            auto then_lens = then_out_shape.lens();
            auto else_lens = else_out_shape.lens();
125

Ted Themistokleous's avatar
Ted Themistokleous committed
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
            // Throw error if both branches have zero output shapes. Not possible for static
            // inputs
            if(then_lens.empty() and else_lens.empty())
            {
                throw_shapes();
            }

            auto handle_empty_branch = [](module_ref& mdl, int index, const shape& out_shape) {
                shape gen_shape(shape(out_shape.type(), {1}, {0}));
                auto literal_ins =
                    mdl->insert_literal(std::prev(mdl->end()), literal(gen_shape, {0}));
                auto unsqueeze_ins = mdl->insert_instruction(
                    std::prev(mdl->end()),
                    make_op("scalar", {{"scalar_bcst_dims", out_shape.lens()}}),
                    literal_ins);
                auto broad_ins = mdl->insert_instruction(
                    std::prev(mdl->end()),
                    make_op("multibroadcast", {{"out_lens", out_shape.lens()}}),
                    unsqueeze_ins);
                auto contig_out = mdl->insert_instruction(
                    std::prev(mdl->end()), make_op("contiguous"), broad_ins);
                mdl->replace_instruction(std::prev(mdl->end())->inputs().at(index), contig_out);
                return out_shape.lens();
            };

            // Handle one empty branch by setting output identical to the other
            // need to update the then_shape before we do further checks
            if(then_lens.empty())
            {
                then_lens = handle_empty_branch(then_mdl, i, else_out_shape);
            }
            else if(else_lens.empty())
            {
                else_lens = handle_empty_branch(else_mdl, i, then_out_shape);
            }

            // check equivalent length dims, and (x1,x2,.., xn, 1) == (x1,x2,..,xn)
            int rank_delta = abs((static_cast<int>(then_lens.size() - else_lens.size())));

            if(rank_delta == 1)
            {
                // make sure dims are equivalent in static shapes
                if(not all_but_last_dims_equal(then_lens, else_lens))
Ted Themistokleous's avatar
Ted Themistokleous committed
169
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
170
                    throw_shapes();
Ted Themistokleous's avatar
Ted Themistokleous committed
171
172
                }

Ted Themistokleous's avatar
Ted Themistokleous committed
173
174
                auto last_then = then_lens.back();
                auto last_else = else_lens.back();
Ted Themistokleous's avatar
Ted Themistokleous committed
175

Ted Themistokleous's avatar
Ted Themistokleous committed
176
177
                // Find which dim to unsqueeze
                if((then_lens.size() < else_lens.size()) && (last_else == 1))
178
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
179
                    unsqueeze_last_op(then_mdl, i, else_lens);
180
                }
Ted Themistokleous's avatar
Ted Themistokleous committed
181
                else if((then_lens.size() > else_lens.size()) && (last_then == 1))
182
                {
Ted Themistokleous's avatar
Ted Themistokleous committed
183
                    unsqueeze_last_op(else_mdl, i, then_lens);
184
185
                }
            }
Ted Themistokleous's avatar
Ted Themistokleous committed
186
187
188
189
            else if(rank_delta > 1)
            {
                throw_shapes();
            }
Shucai Xiao's avatar
Shucai Xiao committed
190
        }
Shucai Xiao's avatar
Shucai Xiao committed
191

Shucai Xiao's avatar
Shucai Xiao committed
192
193
194
        auto if_ret = info.add_instruction(make_op("if"), args, {then_mdl, else_mdl});
        auto out_s  = if_ret->get_shape();
        assert(out_s.type() == shape::tuple_type);
Shucai Xiao's avatar
Shucai Xiao committed
195

Shucai Xiao's avatar
Shucai Xiao committed
196
197
198
199
200
201
        const auto& vec_shapes = out_s.sub_shapes();
        std::vector<instruction_ref> out_inss;
        for(std::size_t i = 0; i < vec_shapes.size(); ++i)
        {
            auto ret = info.add_instruction(make_op("get_tuple_elem", {{"index", i}}), if_ret);
            out_inss.push_back(ret);
Shucai Xiao's avatar
Shucai Xiao committed
202
        }
Shucai Xiao's avatar
Shucai Xiao committed
203
204

        return out_inss;
Shucai Xiao's avatar
Shucai Xiao committed
205
206
207
208
209
210
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx