// rewrite_pooling.cpp
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
Paul's avatar
Paul committed
5
#include <migraphx/op/reshape.hpp>
6
#include <migraphx/op/reduce_mean.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
7
#include <migraphx/op/reduce_max.hpp>
8
9
#include <migraphx/make_op.hpp>

10
11
12
13
14
#include <migraphx/program.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

15
// Replaces eligible "pooling" instructions in the module with the sequence
// reshape -> reduce_mean / reduce_max -> reshape.
//
// A pooling instruction is rewritten only when all of the following hold:
//   * it has at least one input and that input's shape is standard,
//   * every padding value is 0 and every stride value is 1,
//   * the pooling window lengths equal the input dimensions from index 2 on.
// Any instruction that fails one of these checks is left untouched.
void rewrite_pooling::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        // Skip anything that is not a pooling instruction with an input.
        if(ins->name() != "pooling" or ins->inputs().empty())
            continue;

        auto input      = ins->inputs().front();
        auto&& in_shape = input->get_shape();
        if(not in_shape.standard())
            continue;

        auto&& pool_op = any_cast<op::pooling>(ins->get_operator());

        const auto is_zero = [](auto v) { return v == 0; };
        const auto is_one  = [](auto v) { return v == 1; };
        if(not std::all_of(pool_op.padding.begin(), pool_op.padding.end(), is_zero))
            continue;
        if(not std::all_of(pool_op.stride.begin(), pool_op.stride.end(), is_one))
            continue;

        // The window must match the input dimensions that follow the first two.
        auto dims                = in_shape.lens();
        const bool window_is_all = std::equal(
            dims.begin() + 2, dims.end(), pool_op.lengths.begin(), pool_op.lengths.end());
        if(not window_is_all)
            continue;

        std::int64_t first_dim  = dims[0];
        std::int64_t second_dim = dims[1];

        // Collapse the first two dimensions into one row index and everything
        // else into a single axis, so the reduction runs over axis 1.
        auto flattened = m.insert_instruction(
            ins, make_op("reshape", {{"dims", {first_dim * second_dim, -1}}}), input);

        // Average mode maps to reduce_mean; every other mode maps to reduce_max.
        const auto* reduce_name =
            (pool_op.mode == op::pooling_mode::average) ? "reduce_mean" : "reduce_max";
        auto reduced =
            m.insert_instruction(ins, make_op(reduce_name, {{"axes", {1}}}), flattened);

        // Restore the original rank: the first two dimensions are kept and
        // every remaining dimension becomes 1.
        std::vector<int64_t> out_dims(dims.size(), 1);
        out_dims[0] = first_dim;
        out_dims[1] = second_dim;
        m.replace_instruction(ins, make_op("reshape", {{"dims", out_dims}}), reduced);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx