eliminate_pad.cpp 2.49 KB
Newer Older
1
#include <migraphx/eliminate_pad.hpp>
2
3
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
Shucai Xiao's avatar
Shucai Xiao committed
4
5
6
7
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pad.hpp>
8
9
10
11
12
13
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

14
void eliminate_pad::apply(program& p) const
{
    // Walk every instruction. For each convolution / im2col / pooling whose first
    // (data) input is produced by a `pad`, attempt to fold that pad into the
    // consumer's own padding attribute. The helpers decide whether folding is legal.
    for(auto ins : iterator_for(p))
    {
        const std::string name = ins->name();
        const bool can_absorb_pad =
            name == "convolution" or name == "im2col" or name == "pooling";
        if(not can_absorb_pad)
            continue;

        auto producer = ins->inputs().front();
        if(producer->name() != "pad")
            continue;

        if(name == "pooling")
            update_pooling(producer, ins, p);
        else if(name == "im2col")
            update_op(op::im2col{}, producer, ins, p);
        else if(name == "convolution")
            update_op(op::convolution{}, producer, ins, p);
    }
}

Khalique's avatar
Khalique committed
33
template <class T>
34
void eliminate_pad::update_op(T,
Khalique's avatar
Khalique committed
35
36
37
                              const instruction_ref& input,
                              const instruction_ref& ins,
                              program& p) const
38
{
Khalique's avatar
Khalique committed
39
    auto pad_op = any_cast<op::pad>(input->get_operator());
40
41
    if(!pad_op.symmetric())
        return;
Khalique's avatar
Khalique committed
42

43
44
45
46
    auto kdims    = input->get_shape().lens().size() - 2;
    auto kdims_it = pad_op.pads.begin() + 2;

    std::vector<size_t> new_pads(kdims_it, kdims_it + kdims);
47

Khalique's avatar
Khalique committed
48
    T op       = any_cast<T>(ins->get_operator());
49
    op.padding = new_pads;
Khalique's avatar
Khalique committed
50

Khalique's avatar
Khalique committed
51
52
    std::vector<instruction_ref> new_inputs{ins->inputs()};
    new_inputs.front() = input->inputs().front();
Khalique's avatar
Khalique committed
53

Khalique's avatar
Khalique committed
54
    p.replace_instruction(ins, op, new_inputs);
Khalique's avatar
Khalique committed
55
}
56

57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
/// Fold the `pad` instruction `input` into the padding attribute of the pooling
/// instruction `ins`.
///
/// Average pooling is never modified. For other modes, the spatial pad amounts
/// are added to the pooling op's existing padding, and the pad is bypassed.
/// The function returns without changes whenever the pad cannot be expressed
/// as pooling padding.
void eliminate_pad::update_pooling(const instruction_ref& input,
                                   const instruction_ref& ins,
                                   program& p) const
{
    auto pooling_op = any_cast<op::pooling>(ins->get_operator());
    // Average pooling is left untouched (as in the original pass). Check this
    // first so no pad work is done for a case that is always skipped.
    if(pooling_op.mode == "average")
        return;

    auto pad_op = any_cast<op::pad>(input->get_operator());
    // Pooling padding holds one value per spatial axis, so begin/end must match.
    if(!pad_op.symmetric())
        return;

    // NOTE(review): for max pooling this fold is only exact if the padded values
    // can never win the max (e.g. non-negative inputs with a 0 fill value, or a
    // fill value of the dtype's lowest). Confirm before relying on it for
    // arbitrary inputs.

    const auto rank = input->get_shape().lens().size();
    // pads holds one begin amount and one end amount per axis (2 * rank entries).
    if(rank < 2 or pad_op.pads.size() != 2 * rank)
        return;

    // Axes 0 and 1 are not spatial; pooling padding cannot reproduce padding
    // there. By symmetry, checking the begin entries covers the ends too.
    if(pad_op.pads[0] != 0 or pad_op.pads[1] != 0)
        return;

    const auto kdims = rank - 2;
    // Expect exactly one padding entry per spatial axis; leave other layouts alone.
    if(pooling_op.padding.size() != kdims)
        return;

    for(size_t i = 0; i < kdims; ++i)
    {
        const auto amount = pad_op.pads[i + 2];
        // Negative pads crop the input. Converting them to size_t would wrap.
        if(amount < 0)
            return;
        // Accumulate rather than overwrite the pooling op's existing padding.
        pooling_op.padding[i] += static_cast<size_t>(amount);
    }

    // Bypass the pad: feed the pad's own input straight into the pooling op.
    std::vector<instruction_ref> new_inputs{ins->inputs()};
    new_inputs.front() = input->inputs().front();

    p.replace_instruction(ins, pooling_op, new_inputs);
}

84
85
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx