// pad_rewrite.cpp
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Scans every instruction of `p`. For each convolution, im2col or pooling
// instruction whose first input is produced by a "pad" instruction, hands the
// pair to update_op() together with a default-constructed operator of the
// matching type, which update_op() uses only to select its template argument.
void pad_rewrite::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        // Classify the consumer once; the flags drive both the filter and the dispatch.
        const std::string& name = ins->name();
        const bool is_conv      = (name == "convolution");
        const bool is_im2col    = (name == "im2col");
        const bool is_pooling   = (name == "pooling");
        if(not(is_conv or is_im2col or is_pooling))
            continue;

        // Only the first input is inspected for a preceding pad.
        auto producer = ins->inputs().front();
        if(producer->name() != "pad")
            continue;

        if(is_conv)
            update_op(op::convolution{}, producer, ins, p);
        else if(is_im2col)
            update_op(op::im2col{}, producer, ins, p);
        else
            update_op(op::pooling{}, producer, ins, p);
    }
}

template <class T>
Khalique's avatar
Khalique committed
31
void pad_rewrite::update_op(T, const instruction_ref& input, const instruction_ref& ins, program& p) const
32
{
Khalique's avatar
Khalique committed
33
    auto pad_op = any_cast<op::pad>(input->get_operator());
34
35
    if(!pad_op.symmetric())
        return;
Khalique's avatar
Khalique committed
36

37
    std::vector<int64_t> pads = pad_op.pads;
Khalique's avatar
Khalique committed
38
    std::array<size_t, 2> new_pads{static_cast<size_t>(pads[2]), static_cast<size_t>(pads[3])};
39

Khalique's avatar
Khalique committed
40
    T op       = any_cast<T>(ins->get_operator());
41
    op.padding = new_pads;
Khalique's avatar
Khalique committed
42

Khalique's avatar
Khalique committed
43
44
    std::vector<instruction_ref> new_inputs{ins->inputs()};
    new_inputs.front() = input->inputs().front();
Khalique's avatar
Khalique committed
45

Khalique's avatar
Khalique committed
46
    p.replace_instruction(ins, op, new_inputs);
Khalique's avatar
Khalique committed
47
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx