Unverified Commit b5ba22ae authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #349 from ROCmSoftwarePlatform/cse-fix

Fix bug in cse pass
parents 92c62bd0 99bcbf02
......@@ -4,7 +4,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx
auto_contiguous.cpp
common_subexpression_elimination.cpp
eliminate_common_subexpression.cpp
propagate_constant.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
......
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
......@@ -27,13 +27,17 @@ void cse_range(program& p, Range&& r)
if(*eq != *ins)
continue;
p.replace_instruction(ins, eq);
cse_range(p, eq->outputs());
auto outputs = eq->outputs();
std::sort(outputs.begin(), outputs.end(), [&](auto x, auto y) {
return std::distance(eq, x) < std::distance(eq, y);
});
cse_range(p, outputs);
}
instructions.emplace(ins->name(), ins);
}
}
void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); }
void eliminate_common_subexpression::apply(program& p) const { cse_range(p, iterator_for(p)); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -13,9 +13,9 @@ struct program;
/**
* Remove identical instructions.
*/
struct common_subexpression_elimination
struct eliminate_common_subexpression
{
std::string name() const { return "common_subexpression_elimination"; }
std::string name() const { return "eliminate_common_subexpression"; }
void apply(program& p) const;
};
......
......@@ -13,7 +13,7 @@
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
......@@ -49,8 +49,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
rewrite_rnn{},
rewrite_pooling{},
dead_code_elimination{},
// common_subexpression_elimination{},
// dead_code_elimination{},
eliminate_common_subexpression{},
dead_code_elimination{},
simplify_algebra{},
dead_code_elimination{},
auto_contiguous{},
......
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/op/add.hpp>
#include <basic_ops.hpp>
......@@ -9,7 +9,7 @@ struct cse_target
std::string name() const { return "dce"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::common_subexpression_elimination{}, migraphx::dead_code_elimination{}};
return {migraphx::eliminate_common_subexpression{}, migraphx::dead_code_elimination{}};
}
migraphx::context get_context() const { return {}; }
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment