fuse_reduce.cpp 5.99 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
Paul's avatar
Paul committed
34
#include <migraphx/register_op.hpp>
Paul's avatar
Paul committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <iterator>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
            MIGRAPHX_THROW("should have one submodule.");
Paul's avatar
Format  
Paul committed
54
        auto* sm = mods.front();
Paul's avatar
Paul committed
55
56
        if(sm->get_output_shapes().size() != 1)
            MIGRAPHX_THROW("Only one output supported");
Paul's avatar
Paul committed
57
        check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();
Paul's avatar
Format  
Paul committed
58
59
        auto s    = inputs.at(0);
        auto lens = s.lens();
Paul's avatar
Format  
Paul committed
60
        if(lens != sm->get_output_shapes().front().lens())
Paul's avatar
Paul committed
61
        {
Paul's avatar
Paul committed
62
63
64
65
            for(const auto& axis : axes)
            {
                lens[axis] = 1;
            }
Paul's avatar
Paul committed
66
        }
Paul's avatar
Paul committed
67

Paul's avatar
Format  
Paul committed
68
69
        return shape::from_permutation(
            sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
Paul's avatar
Paul committed
70
71
72
73
    }

    std::string name() const { return "fused_reduce"; }
};
Paul's avatar
Paul committed
74
MIGRAPHX_REGISTER_OP(fused_reduce);
Paul's avatar
Paul committed
75
76
77
78
79
80
81
82

static void create_reduce_modules(module_pass_manager& mpm)
{
    std::size_t n = 0;
    for(auto ins : iterator_for(mpm.get_module()))
    {
        if(not ins->get_operator().attributes().get("reduce", false))
            continue;
Paul's avatar
Format  
Paul committed
83
        if(ins->inputs().size() != 1)
Paul's avatar
Paul committed
84
85
            continue;

Paul's avatar
Format  
Paul committed
86
87
        auto* rm =
            mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
Paul's avatar
Paul committed
88
89
90
91
        rm->set_bypass();

        // TODO: Ensure standard shape
        auto x0 = rm->add_parameter("x0", ins->inputs().front()->get_shape());
Paul's avatar
Format  
Paul committed
92
        auto r  = rm->add_instruction(ins->get_operator(), x0);
Paul's avatar
Paul committed
93
94
        rm->add_return({r});

Paul's avatar
Paul committed
95
        auto v = ins->get_operator().to_value();
Paul's avatar
Format  
Paul committed
96
97
        mpm.get_module().replace_instruction(
            ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
Paul's avatar
Paul committed
98
99
100
    }
}

Paul's avatar
Format  
Paul committed
101
102
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
Paul's avatar
Paul committed
103
104
105
106
107
{
    std::unordered_map<instruction_ref, instruction_ref> result;
    auto names = sm->get_parameter_names();
    std::sort(names.begin(), names.end());
    assert(names.size() == inputs.size());
Paul's avatar
Format  
Paul committed
108
109
110
111
112
113
114
    std::transform(names.begin(),
                   names.end(),
                   inputs.begin(),
                   std::inserter(result, result.end()),
                   [&](const auto& name, auto input) {
                       return std::make_pair(input, sm->get_parameter(name));
                   });
Paul's avatar
Paul committed
115
116
117
118
119
120
    return result;
}

static std::vector<instruction_ref> get_returns(module& m)
{
    auto last = std::prev(m.end());
Paul's avatar
Format  
Paul committed
121
    if(last->name() == "@return")
Paul's avatar
Paul committed
122
123
124
125
126
127
128
129
        return last->inputs();
    return {last};
}

struct find_reduce_pointwise
{
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
130
131
        return match::name("pointwise")(match::any_of[match::inputs()](
            match::name("fused_reduce")(match::used_once()).bind("reduce")));
Paul's avatar
Paul committed
132
133
134
135
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
136
        auto ins    = r.result;
Paul's avatar
Paul committed
137
138
        auto reduce = r.instructions["reduce"];

Paul's avatar
Paul committed
139
        const auto* old_rm = reduce->module_inputs().front();
Paul's avatar
Format  
Paul committed
140
        auto* rm           = mpm.create_module(old_rm->name() + ":pointwise");
Paul's avatar
Paul committed
141
142
143
        rm->set_bypass();
        // Copy module instructions
        rm->add_instructions(old_rm);
Paul's avatar
Format  
Paul committed
144
        auto map_ins    = get_ins_param_map(reduce->inputs(), rm);
Paul's avatar
Paul committed
145
        auto new_inputs = reduce->inputs();
Paul's avatar
Format  
Paul committed
146
        for(auto input : ins->inputs())
Paul's avatar
Paul committed
147
148
149
        {
            if(contains(map_ins, input))
                continue;
Paul's avatar
Format  
Paul committed
150
            if(input == reduce)
Paul's avatar
Paul committed
151
            {
Paul's avatar
Paul committed
152
153
154
155
                map_ins[input] = get_returns(*rm).front();
            }
            else
            {
Paul's avatar
Format  
Paul committed
156
157
                map_ins[input] =
                    rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
Paul's avatar
Paul committed
158
                new_inputs.push_back(input);
Paul's avatar
Paul committed
159
160
161
            }
        }

Paul's avatar
Paul committed
162
163
        auto out = rm->add_instructions({ins}, map_ins);
        rm->add_return(out);
Paul's avatar
Paul committed
164
165
166
167
168
169
170
171
172
        mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm});
    }
};

void fuse_reduce::apply(module_pass_manager& mpm) const
{
    create_reduce_modules(mpm);
    mpm.run_pass(dead_code_elimination{});
    match::find_matches(mpm, find_reduce_pointwise{});
Paul's avatar
Paul committed
173
    mpm.run_pass(dead_code_elimination{});
Paul's avatar
Paul committed
174
175
176
177
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx