"vscode:/vscode.git/clone" did not exist on "6a6328a5e941751bd3481442a8136f7d563cfdfd"
fuse_reduce.cpp 10.9 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
Paul's avatar
Paul committed
34
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
35
#include <iterator>
Paul's avatar
Paul committed
36
#include <map>
Paul's avatar
Paul committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
Paul's avatar
Format  
Paul committed
55
        auto* sm = mods.front();
Paul's avatar
Paul committed
56
57
        if(sm->get_output_shapes().size() != 1)
            MIGRAPHX_THROW("Only one output supported");
Paul's avatar
Paul committed
58
        check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();
Paul's avatar
Format  
Paul committed
59
60
        auto s    = inputs.at(0);
        auto lens = s.lens();
Paul's avatar
Format  
Paul committed
61
        if(lens != sm->get_output_shapes().front().lens())
Paul's avatar
Paul committed
62
        {
Paul's avatar
Paul committed
63
64
65
66
            for(const auto& axis : axes)
            {
                lens[axis] = 1;
            }
Paul's avatar
Paul committed
67
        }
Paul's avatar
Paul committed
68

Paul's avatar
Format  
Paul committed
69
70
        return shape::from_permutation(
            sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
Paul's avatar
Paul committed
71
72
73
74
    }

    std::string name() const { return "fused_reduce"; }
};
Paul's avatar
Paul committed
75
MIGRAPHX_REGISTER_OP(fused_reduce);
Paul's avatar
Paul committed
76

Paul's avatar
Paul committed
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
    std::unordered_map<instruction_ref, instruction_ref> result;
    auto names = sm->get_parameter_names();
    std::sort(names.begin(), names.end());
    assert(names.size() == inputs.size());
    std::transform(names.begin(),
                   names.end(),
                   inputs.begin(),
                   std::inserter(result, result.end()),
                   [&](const auto& name, auto input) {
                       return std::make_pair(input, sm->get_parameter(name));
                   });
    return result;
}

Paul's avatar
Format  
Paul committed
94
95
96
/// Give `sm` a new parameter for every input of `ins` that `map_ins` does not
/// map yet, and record input -> new parameter in `map_ins`.
///
/// New parameters are named "x<k>". The index k starts at the number of
/// parameters `sm` already has and increases by one per added parameter.
static void insert_params(module_ref sm,
                          instruction_ref ins,
                          std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    auto next_index = sm->get_parameter_shapes().size();
    for(auto arg : ins->inputs())
    {
        if(not contains(map_ins, arg))
        {
            // TODO: Ensure standard shape
            std::string param_name = "x" + std::to_string(next_index);
            ++next_index;
            map_ins[arg] = sm->add_parameter(param_name, arg->get_shape());
        }
    }
}

Paul's avatar
Format  
Paul committed
108
109
110
static auto insert_ins_in_submodule(module_ref sm,
                                    instruction_ref ins,
                                    std::unordered_map<instruction_ref, instruction_ref>& map_ins)
Paul's avatar
Paul committed
111
112
113
114
115
116
117
118
119
120
121
{
    insert_params(sm, ins, map_ins);
    return sm->add_instructions({ins}, map_ins);
}

static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    return insert_ins_in_submodule(sm, ins, map_ins);
}

Paul's avatar
Format  
Paul committed
122
123
124
125
static auto
insert_module_in_submodule(module_ref sm,
                           instruction_ref ins,
                           std::unordered_map<instruction_ref, instruction_ref>& map_ins)
Paul's avatar
Paul committed
126
127
{
    insert_params(sm, ins, map_ins);
Paul's avatar
Format  
Paul committed
128
    auto* m        = ins->module_inputs().front();
Paul's avatar
Paul committed
129
    auto param_map = get_ins_param_map(ins->inputs(), m);
Paul's avatar
Format  
Paul committed
130
    for(auto&& [input, param] : param_map)
Paul's avatar
Paul committed
131
132
133
134
135
136
    {
        map_ins[param] = map_ins.at(input);
    }
    return sm->add_instructions(m, map_ins);
}

Paul's avatar
Format  
Paul committed
137
138
/// Recover the outer instructions that feed the parameters of `sm`.
///
/// Keeps the entries of `map_ins` whose value is a "@param" instruction of
/// `sm`. It returns their keys ordered by parameter name, in lexicographic
/// order via std::map. This is the order get_ins_param_map expects.
/// NOTE(review): if several keys map to the same parameter, the one visited
/// last in the unordered_map wins, so callers must ensure only outer
/// instructions map to parameters of `sm`.
static std::vector<instruction_ref>
find_inputs(module_ref sm, const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
    std::map<std::string, instruction_ref> by_name;
    for(auto&& [outer, mapped] : map_ins)
    {
        const bool is_sm_param = sm->has_instruction(mapped) and mapped->name() == "@param";
        if(not is_sm_param)
            continue;
        const auto op_value   = mapped->get_operator().to_value();
        const auto param_name = op_value.at("parameter").to<std::string>();
        by_name[param_name]   = outer;
    }

    std::vector<instruction_ref> ordered;
    ordered.reserve(by_name.size());
    for(const auto& entry : by_name)
        ordered.push_back(entry.second);
    return ordered;
}

Paul's avatar
Paul committed
158
159
160
161
162
163
164
/// Wrap each single-input instruction whose operator has the "reduce"
/// attribute into a "fused_reduce" instruction with its own submodule.
///
/// Each submodule is named "<module>:<op name><counter>" and marked bypass.
/// It holds a copy of the reduce instruction, fed by a fresh parameter, and
/// returns its result. The new fused_reduce takes the reduce's "axes" value.
static void create_reduce_modules(module_pass_manager& mpm)
{
    std::size_t counter = 0;
    for(auto ins : iterator_for(mpm.get_module()))
    {
        const bool is_reduce = ins->get_operator().attributes().get("reduce", false);
        if(not is_reduce or ins->inputs().size() != 1)
            continue;

        const std::string module_name =
            mpm.get_module().name() + ":" + ins->name() + std::to_string(counter);
        ++counter;
        auto* rm = mpm.create_module(module_name);
        rm->set_bypass();
        rm->add_return(insert_ins_in_submodule(rm, ins));

        auto op_value = ins->get_operator().to_value();
        auto fused_op = make_op("fused_reduce", {{"axes", op_value["axes"]}});
        mpm.get_module().replace_instruction(ins, fused_op, ins->inputs(), {rm});
    }
}

/// Return the values module `m` produces.
///
/// @return the inputs of the trailing "@return" instruction if there is one,
///         otherwise the last instruction alone. An empty module yields an
///         empty vector; computing std::prev(m.end()) on an empty instruction
///         list would be undefined behavior.
static std::vector<instruction_ref> get_returns(module& m)
{
    if(m.begin() == m.end())
        return {};
    auto last = std::prev(m.end());
    if(last->name() == "@return")
        return last->inputs();
    return {last};
}

Paul's avatar
Paul committed
188
189
190
191
192
namespace {
struct find_pointwise_reduce
{
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
193
194
        return match::name("fused_reduce")(match::any_of[match::inputs()](
            match::name("pointwise")(match::used_once()).bind("pointwise")));
Paul's avatar
Paul committed
195
196
197
198
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
199
200
        auto reduce = r.result;
        auto pw     = r.instructions["pointwise"];
Paul's avatar
Paul committed
201
202
203

        const auto* pm = pw->module_inputs().front();
        // const auto* old_rm = reduce->module_inputs().front();
Paul's avatar
Format  
Paul committed
204
        auto* rm = mpm.create_module(pm->name() + ":reduce");
Paul's avatar
Paul committed
205
206
207
208
        rm->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Insert pointwise
Paul's avatar
Format  
Paul committed
209
        auto rins   = insert_ins_in_submodule(rm, pw, map_ins).front();
Paul's avatar
Paul committed
210
211
212
213
214
215
216
217
218
        map_ins[pw] = rins;
        // Insert fused_reduce
        insert_module_in_submodule(rm, reduce, map_ins);

        auto new_inputs = find_inputs(rm, map_ins);
        mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
    }
};

Paul's avatar
Paul committed
219
220
struct find_reduce_pointwise
{
Paul's avatar
Format  
Paul committed
221
    template <class... Ms>
Paul's avatar
Paul committed
222
223
    static auto match_broadcast(Ms... ms)
    {
Paul's avatar
Format  
Paul committed
224
225
        return match::skip(match::name("contiguous"))(
            match::name("multibroadcast")(match::arg(0)(ms...)).bind("broadcast"));
Paul's avatar
Paul committed
226
227
    }

Paul's avatar
Format  
Paul committed
228
    template <class... Ms>
Paul's avatar
Paul committed
229
230
231
232
233
    static auto any_input(Ms... ms)
    {
        return match::any_of[match::inputs()](match::any(ms...).bind("input"));
    }

Paul's avatar
Paul committed
234
235
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
236
        auto reduce       = match::name("fused_reduce")(match::used_once()).bind("reduce");
Paul's avatar
Paul committed
237
238
239
        auto reduce_input = any_input(reduce);
        auto broadcast_reduce_input = any_input(match_broadcast(reduce));
        return match::name("pointwise")(match::any_of(reduce_input, broadcast_reduce_input));
Paul's avatar
Paul committed
240
241
242
243
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
244
        auto pw     = r.result;
Paul's avatar
Paul committed
245
        auto reduce = r.instructions["reduce"];
Paul's avatar
Format  
Paul committed
246
        auto input  = r.instructions["input"];
Paul's avatar
Paul committed
247

Paul's avatar
Paul committed
248
        const auto* old_rm = reduce->module_inputs().front();
Paul's avatar
Format  
Paul committed
249
        auto* rm           = mpm.create_module(old_rm->name() + ":pointwise");
Paul's avatar
Paul committed
250
        rm->set_bypass();
Paul's avatar
Paul committed
251
        std::unordered_map<instruction_ref, instruction_ref> map_ins;
Paul's avatar
Paul committed
252
        // Copy module instructions
Paul's avatar
Paul committed
253
        insert_module_in_submodule(rm, reduce, map_ins);
Paul's avatar
Paul committed
254
255
        if(contains(r.instructions, "broadcast"))
        {
Paul's avatar
Format  
Paul committed
256
            auto broadcast                       = r.instructions["broadcast"];
Paul's avatar
Paul committed
257
            map_ins[broadcast->inputs().front()] = get_returns(*rm).front();
Paul's avatar
Format  
Paul committed
258
259
            auto bout                            = insert_ins_in_submodule(rm, broadcast, map_ins);
            map_ins[input]                       = bout.front();
Paul's avatar
Paul committed
260
261
262
263
264
        }
        else
        {
            map_ins[input] = get_returns(*rm).front();
        }
Paul's avatar
Paul committed
265

Paul's avatar
Paul committed
266
267
268
269
270
        auto out = insert_ins_in_submodule(rm, pw, map_ins);
        rm->replace_return(out);

        auto new_inputs = find_inputs(rm, map_ins);
        mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
Paul's avatar
Paul committed
271
272
    }
};
Paul's avatar
Paul committed
273
274
275
276
277
278
279
280
281
282
283
284
285
286

/// Fuses a fused_reduce used only once into another fused_reduce that
/// consumes it. Both must carry an identical operator (the only reflected
/// attribute is axes). The new submodule runs the producer's body, then the
/// consumer's body on its result.
struct find_reduce_reduce
{
    auto matcher() const
    {
        auto producer = match::name("fused_reduce")(match::used_once()).bind("reduce");
        return match::name("fused_reduce")(match::any_of[match::inputs()](producer));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto consumer = r.result;
        auto producer = r.instructions["reduce"];

        // Only reductions with identical operators are merged.
        if(consumer->get_operator() != producer->get_operator())
            return;

        const auto* consumer_mod = consumer->module_inputs().front();
        const auto* producer_mod = producer->module_inputs().front();
        auto* fused_mod = mpm.create_module(consumer_mod->name() + ":" + producer_mod->name());
        fused_mod->set_bypass();

        std::unordered_map<instruction_ref, instruction_ref> map_ins;
        // Inline the producer's body first...
        insert_module_in_submodule(fused_mod, producer, map_ins);
        // ...then let the consumer read the producer's result inside the submodule.
        map_ins[producer] = get_returns(*fused_mod).front();
        auto out = insert_module_in_submodule(fused_mod, consumer, map_ins);
        fused_mod->replace_return(out);

        auto new_inputs = find_inputs(fused_mod, map_ins);
        mpm.get_module().replace_instruction(
            consumer, consumer->get_operator(), new_inputs, {fused_mod});
    }
};

Paul's avatar
Format  
Paul committed
308
} // namespace
Paul's avatar
Paul committed
309
310
311
312
313

/// Pass entry point. Wraps every eligible reduce in a fused_reduce, then runs
/// a fixed number of fusion rounds (reduce->pointwise, pointwise->reduce,
/// reduce->reduce). Each round is followed by dead code elimination.
void fuse_reduce::apply(module_pass_manager& mpm) const
{
    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});

    // Several rounds, presumably so that one fusion can expose another.
    const int fusion_rounds = 4;
    for(int round = 0; round < fusion_rounds; ++round)
    {
        match::find_matches(
            mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
        mpm.run_pass(dead_code_elimination{});
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx