Commit 71d4f4c5 authored by Khalique's avatar Khalique
Browse files

added files for eliminating identity

parent 154f21f6
...@@ -10,6 +10,7 @@ add_library(migraphx ...@@ -10,6 +10,7 @@ add_library(migraphx
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp eliminate_concat.cpp
eliminate_identity.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
env.cpp env.cpp
......
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
void eliminate_identity::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->get_operator().name() == "identity")
{
if(ins != p.end())
{
instruction_ref identity_input{ins->inputs().at(0)};
auto next_ins = std::next(ins);
std::vector<instruction_ref> next_ins_inputs{next_ins->inputs()};
for (auto& input : next_ins_inputs)
{
if(input == ins)
{
input = identity_input;
}
}
p.replace_instruction(next_ins, next_ins->get_operator(), next_ins_inputs);
}
}
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_IDENTITY_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Remove identity instructions.
*/
struct eliminate_identity
{
std::string name() const { return "eliminate_identity"; }
void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <migraphx/fwd_conv_batchnorm_rewrite.hpp> #include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_concat.hpp> #include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
namespace migraphx { namespace migraphx {
...@@ -30,6 +31,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -30,6 +31,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
return return
{ {
dead_code_elimination{}, dead_code_elimination{},
eliminate_identity{},
fwd_conv_batchnorm_rewrite{}, fwd_conv_batchnorm_rewrite{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment