Commit 1290f3ba authored by Paul's avatar Paul
Browse files

Fix bug in cse pass

parent 015631a1
...@@ -4,7 +4,7 @@ include(ROCMPackageConfigHelpers) ...@@ -4,7 +4,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx add_library(migraphx
auto_contiguous.cpp auto_contiguous.cpp
common_subexpression_elimination.cpp eliminate_common_subexpression.cpp
propagate_constant.cpp propagate_constant.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
......
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
...@@ -27,13 +27,17 @@ void cse_range(program& p, Range&& r) ...@@ -27,13 +27,17 @@ void cse_range(program& p, Range&& r)
if(*eq != *ins) if(*eq != *ins)
continue; continue;
p.replace_instruction(ins, eq); p.replace_instruction(ins, eq);
cse_range(p, eq->outputs()); auto outputs = eq->outputs();
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y);
});
cse_range(p, outputs);
} }
instructions.emplace(ins->name(), ins); instructions.emplace(ins->name(), ins);
} }
} }
void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); } void eliminate_common_subexpression::apply(program& p) const { cse_range(p, iterator_for(p)); }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -13,9 +13,9 @@ struct program; ...@@ -13,9 +13,9 @@ struct program;
/** /**
* Remove identical instructions. * Remove identical instructions.
*/ */
struct common_subexpression_elimination struct eliminate_common_subexpression
{ {
std::string name() const { return "common_subexpression_elimination"; } std::string name() const { return "eliminate_common_subexpression"; }
void apply(program& p) const; void apply(program& p) const;
}; };
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp> #include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
...@@ -49,8 +49,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -49,8 +49,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
rewrite_rnn{}, rewrite_rnn{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
// common_subexpression_elimination{}, eliminate_common_subexpression{},
// dead_code_elimination{}, dead_code_elimination{},
simplify_algebra{}, simplify_algebra{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
......
#include <migraphx/common_subexpression_elimination.hpp> #include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp> #include <migraphx/op/add.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -9,7 +9,7 @@ struct cse_target ...@@ -9,7 +9,7 @@ struct cse_target
std::string name() const { return "dce"; } std::string name() const { return "dce"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraphx::common_subexpression_elimination{}, migraphx::dead_code_elimination{}}; return {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}};
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment