rewrite_pooling.cpp 1.83 KB
Newer Older
1
2
3
4
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reduce_mean.hpp>
#include <migraphx/op/reduce_max.hpp>
#include <migraphx/program.hpp>

#include <algorithm>
#include <cstdint>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

13
void rewrite_pooling::apply(module& prog) const
14
15
16
{
    for(auto ins : iterator_for(prog))
    {
Paul's avatar
Paul committed
17
        if(ins->name() != "pooling")
18
            continue;
Paul's avatar
Paul committed
19
        if(ins->inputs().empty())
20
            continue;
21
22
23
        auto&& s = ins->inputs().front()->get_shape();
        if(not s.standard())
            continue;
24
        auto&& op = any_cast<op::pooling>(ins->get_operator());
Shucai Xiao's avatar
Shucai Xiao committed
25
        if(!std::all_of(op.padding.begin(), op.padding.end(), [](auto i) { return i == 0; }))
26
            continue;
Shucai Xiao's avatar
Shucai Xiao committed
27
        if(!std::all_of(op.stride.begin(), op.stride.end(), [](auto i) { return i == 1; }))
28
            continue;
Shucai Xiao's avatar
Shucai Xiao committed
29
30
        auto lens = s.lens();
        if(!std::equal(lens.begin() + 2, lens.end(), op.lengths.begin(), op.lengths.end()))
31
            continue;
Paul's avatar
Paul committed
32
33
        std::int64_t n = s.lens()[0];
        std::int64_t c = s.lens()[1];
Paul's avatar
Paul committed
34
35
        auto reshape =
            prog.insert_instruction(ins, op::reshape{{n * c, -1}}, ins->inputs().front());
Shucai Xiao's avatar
Shucai Xiao committed
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
        instruction_ref pooling{};

        // average pooling
        if(op.mode == "average")
        {
            pooling = prog.insert_instruction(ins, op::reduce_mean{{1}}, reshape);
        }
        // max pooling
        else
        {
            pooling = prog.insert_instruction(ins, op::reduce_max{{1}}, reshape);
        }

        std::vector<int64_t> rsp_lens(lens.size(), 1);
        rsp_lens[0] = n;
        rsp_lens[1] = c;
        prog.replace_instruction(ins, op::reshape{rsp_lens}, pooling);
53
54
55
56
57
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx