pad_rewrite.cpp 1.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <migraphx/pad_rewrite.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

// Fold explicit "pad" instructions into the implicit padding of the ops that
// consume them (convolution, im2col, pooling). The per-op legality checks and
// the actual rewrite live in update_op; this pass only selects candidates.
void pad_rewrite::apply(program& p) const
{
    for(auto ins : iterator_for(p))
    {
        if(ins->name() != "pad")
            continue;
        // Iterate over a copy: update_op calls replace_instruction on a consumer
        // with inputs that no longer include `ins`, which can remove that
        // consumer from ins->outputs() while we would still be iterating it.
        auto outputs = ins->outputs();
        for(auto output : outputs)
        {
            // update_op swaps the consumer's FIRST input for the pad's input, so
            // only fold when the pad actually feeds that data input (and not,
            // say, a second operand such as convolution weights).
            const auto& args = output->inputs();
            if(args.empty() or args.front() != ins)
                continue;

            auto op_name = output->name();
            if(op_name == "convolution")
                update_op(op::convolution{}, ins, output, p);
            else if(op_name == "im2col")
                update_op(op::im2col{}, ins, output, p);
            else if(op_name == "pooling")
                update_op(op::pooling{}, ins, output, p);
        }
    }
}

// Try to absorb the pad instruction `ins` into the padding of its consumer
// `output` (whose operator is of type T), rewiring `output` to read the pad's
// input directly. Bails out, leaving the program unchanged, whenever the pad
// cannot be expressed as the consumer's implicit padding.
//   ins    - the "pad" instruction; expected to be output's first input.
//   output - the consuming convolution / im2col / pooling instruction.
//   p      - the program being rewritten in place.
template <class T>
void pad_rewrite::update_op(T, instruction_ref ins, instruction_ref output, program& p) const
{
    auto pad_op = any_cast<op::pad>(ins->get_operator());
    // Consumer padding is applied equally before and after each spatial dim.
    if(!pad_op.symmetric())
        return;

    const std::vector<int64_t>& pads = pad_op.pads;
    // Only 4-D inputs are handled (4 dims * 2 for front and back padding); the
    // target ops take a 2-element spatial padding. This was an assert, which
    // disappears in release builds and let a wrong-sized pads vector through.
    if(pads.size() != 8)
        return;
    // Indices follow the original code: pads[2], pads[3] are the spatial (H, W)
    // begin pads, so pads[0], pads[1] are the batch/channel begin pads.
    // Padding those dims changes the tensor shape and cannot be folded.
    if(pads[0] != 0 or pads[1] != 0)
        return;
    // Negative pads (cropping) cannot be represented as unsigned padding.
    if(pads[2] < 0 or pads[3] < 0)
        return;
    // Convolution / im2col pad implicitly with zeros, so only a zero-filled pad
    // is equivalent. NOTE(review): pooling keeps the original behaviour; whether
    // the fill value matches the pooling mode's implicit padding is unchecked.
    if(output->name() != "pooling" and pad_op.value != 0.0f)
        return;

    T op = any_cast<T>(output->get_operator());
    // Accumulate rather than overwrite: the consumer may already pad on its own.
    op.padding[0] += static_cast<std::size_t>(pads[2]);
    op.padding[1] += static_cast<std::size_t>(pads[3]);

    // Bypass the pad: the consumer now reads the pad's own input.
    auto new_inputs         = output->inputs();
    new_inputs.front() = ins->inputs().front();

    p.replace_instruction(output, op, new_inputs);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx