eliminate_identity.cpp 971 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

/// Bypass "identity" instructions by rewiring their consumers.
///
/// Every instruction in the program is visited in order. Each argument that
/// is an "identity" instruction is replaced by that identity's first input.
/// Chains such as identity(identity(x)) are followed all the way back to the
/// first non-identity producer, so the result does not depend on the order
/// in which instructions are visited.
///
/// The identity instructions themselves are NOT removed by this pass; it only
/// stops other instructions from consuming them.
///
/// @param p program rewritten in place.
void eliminate_identity::apply(program& p) const
{
    // Walk back through consecutive identity instructions and return the
    // first producer that is not an identity. `.at(0)` throws if a malformed
    // identity has no input, rather than reading past the end.
    auto skip_identities = [](instruction_ref arg) {
        while(arg->name() == "identity")
            arg = arg->inputs().at(0);
        return arg;
    };

    for(auto ins : iterator_for(p))
    {
        std::vector<instruction_ref> new_inputs = ins->inputs();
        for(auto& arg : new_inputs)
            arg = skip_identities(arg);

        // Only rebuild instructions whose argument list actually changed.
        if(new_inputs != ins->inputs())
            p.replace_instruction(ins, ins->get_operator(), new_inputs);
    }
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx