fuse_reduce.cpp 5.72 KB
Newer Older
Paul's avatar
Paul committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct fused_reduce
{
    std::vector<std::int64_t> axes{};

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axes, "axes"));
    }

    shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
    {
        if(mods.size() != 1)
        {
            MIGRAPHX_THROW("should have one submodule.");
        }
Paul's avatar
Format  
Paul committed
55
        auto* sm = mods.front();
Paul's avatar
Paul committed
56
        check_shapes{inputs, *this}.has(sm->get_parameter_shapes().size()).same_dims();
Paul's avatar
Format  
Paul committed
57
58
        auto s    = inputs.at(0);
        auto lens = s.lens();
Paul's avatar
Paul committed
59
60
61
62
        for(const auto& axis : axes)
        {
            lens[axis] = 1;
        }
Paul's avatar
Format  
Paul committed
63
        if(sm->get_output_shapes().size() != 1)
Paul's avatar
Paul committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
            MIGRAPHX_THROW("Only one output supported");
        return inputs[0].with_lens(sm->get_output_shapes().front().type(), lens);
    }

    std::string name() const { return "fused_reduce"; }
};

// Wraps every single-input reduction in the current module into its own bypass
// submodule and replaces it with a `fused_reduce` that calls that submodule.
// The reduce op's `axes` are forwarded so fused_reduce computes the reduced shape.
static void create_reduce_modules(module_pass_manager& mpm)
{
    std::size_t n = 0;
    for(auto ins : iterator_for(mpm.get_module()))
    {
        if(not ins->get_operator().attributes().get("reduce", false))
            continue;
        if(ins->inputs().size() != 1)
            continue;

        // Suffix with a running counter so each submodule name is unique
        auto* rm =
            mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
        rm->set_bypass();

        // TODO: Ensure standard shape
        auto x0 = rm->add_parameter("x0", ins->inputs().front()->get_shape());
        auto r  = rm->add_instruction(ins->get_operator(), x0);
        rm->add_return({r});

        // Without the axes, fused_reduce::compute_shape would return the unreduced
        // input lens instead of the reduction's output lens.
        auto v = ins->get_operator().to_value();
        mpm.get_module().replace_instruction(
            ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
    }
}

Paul's avatar
Format  
Paul committed
95
96
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
Paul's avatar
Paul committed
97
98
99
100
101
{
    std::unordered_map<instruction_ref, instruction_ref> result;
    auto names = sm->get_parameter_names();
    std::sort(names.begin(), names.end());
    assert(names.size() == inputs.size());
Paul's avatar
Format  
Paul committed
102
103
104
105
106
107
108
    std::transform(names.begin(),
                   names.end(),
                   inputs.begin(),
                   std::inserter(result, result.end()),
                   [&](const auto& name, auto input) {
                       return std::make_pair(input, sm->get_parameter(name));
                   });
Paul's avatar
Paul committed
109
110
111
112
113
114
    return result;
}

// Returns the outputs of module `m`. If the final instruction is an explicit
// "@return", its inputs are the outputs. Otherwise the final instruction itself is
// the sole output. `m` is assumed to be non-empty.
static std::vector<instruction_ref> get_returns(module& m)
{
    const auto final_ins = std::prev(m.end());
    return final_ins->name() == "@return" ? final_ins->inputs()
                                          : std::vector<instruction_ref>{final_ins};
}

struct find_reduce_pointwise
{
    auto matcher() const
    {
Paul's avatar
Format  
Paul committed
124
125
        return match::name("pointwise")(match::any_of[match::inputs()](
            match::name("fused_reduce")(match::used_once()).bind("reduce")));
Paul's avatar
Paul committed
126
127
128
129
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
Paul's avatar
Format  
Paul committed
130
        auto ins    = r.result;
Paul's avatar
Paul committed
131
132
133
        auto reduce = r.instructions["reduce"];

        auto* old_rm = reduce->module_inputs().front();
Paul's avatar
Format  
Paul committed
134
        auto* rm     = mpm.create_module(old_rm->name() + ":pointwise");
Paul's avatar
Paul committed
135
        // Copy module
Paul's avatar
Format  
Paul committed
136
137
        *rm             = *old_rm;
        auto map_ins    = get_ins_param_map(reduce->inputs(), rm);
Paul's avatar
Paul committed
138
        auto new_inputs = reduce->inputs();
Paul's avatar
Format  
Paul committed
139
        for(auto input : ins->inputs())
Paul's avatar
Paul committed
140
141
142
        {
            if(contains(map_ins, input))
                continue;
Paul's avatar
Format  
Paul committed
143
            if(input == reduce)
Paul's avatar
Paul committed
144
            {
Paul's avatar
Paul committed
145
146
147
148
                map_ins[input] = get_returns(*rm).front();
            }
            else
            {
Paul's avatar
Format  
Paul committed
149
150
                map_ins[input] =
                    rm->add_parameter("x" + std::to_string(new_inputs.size()), input->get_shape());
Paul's avatar
Paul committed
151
                new_inputs.push_back(input);
Paul's avatar
Paul committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
            }
        }

        auto out = rm->insert_instructions(std::prev(rm->end()), {ins}, map_ins);
        rm->replace_return(out);
        mpm.get_module().replace_instruction(ins, reduce->get_operator(), new_inputs, {rm});
    }
};

// Pass entry point. First wraps reductions into fused_reduce submodules, then fuses
// single-use reduce -> pointwise chains into them. Dead code elimination runs after
// each phase to drop the instructions that were replaced.
void fuse_reduce::apply(module_pass_manager& mpm) const
{
    const auto cleanup = [&] { mpm.run_pass(dead_code_elimination{}); };

    create_reduce_modules(mpm);
    cleanup();

    match::find_matches(mpm, find_reduce_pointwise{});
    cleanup();
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx