Commit fd150551 authored by Khalique's avatar Khalique
Browse files

added tests, adjusted pass to eliminate identities only

parent 0ebe4abe
...@@ -10,13 +10,59 @@ ...@@ -10,13 +10,59 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
{
auto start_forward = start;
auto start_backwards = start;
std::size_t n = 0;
while(start_forward != last and start_backwards != last)
{
n++;
if(start_forward != r.end())
start_forward++;
if(start_backwards != r.begin())
start_backwards--;
}
if(start_forward == last)
return n;
else
return -n;
}
void eliminate_identity::apply(program& p) const void eliminate_identity::apply(program& p) const
{ {
auto last = std::prev(p.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if(ins->name() == "identity") // Skip the first instruction, since we always process the previous
p.replace_instruction(ins, ins->inputs().front()); // instruction
if(ins == p.begin())
continue;
const auto i = std::prev(ins);
if(i->name() == "identity")
{
p.replace_instruction(i, i->inputs().front());
p.move_instruction(i, p.end());
}
if(ins == last)
{
if(ins->name() == "identity")
{
const instruction_ref& identity_input = i->inputs().front();
if(identity_input->outputs().size() == 1)
{
p.move_instruction(identity_input, i);
// since this is the last instruction, removing it only
// requires changing "last" and calling remove below
last = std::prev(last);
}
}
break;
}
} }
p.remove_instructions(std::next(last), p.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -61,7 +61,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -61,7 +61,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_workspace{}, eliminate_workspace{},
eliminate_allocation{"hip::allocate"}, eliminate_allocation{"hip::allocate"},
check_context<context>{}, check_context<context>{},
dead_code_elimination{} dead_code_elimination{},
eliminate_identity{}
}; };
// clang-format on // clang-format on
} }
......
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <migraphx/operators.hpp>
#include <test.hpp>
struct eliminate_identity_target
{
std::string name() const { return "eliminate_identity"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::eliminate_identity{}};
}
migraphx::context get_context() const { return {}; }
};
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto one_identity = p.add_instruction(migraphx::op::identity{}, one);
auto two = p.add_literal(2);
auto two_identity = p.add_instruction(migraphx::op::identity{}, two);
p.add_instruction(sum_op{}, one_identity, two_identity);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins){ return ins.name() == "identity"; }));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins){ return ins.name() == "identity"; }));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3});
}
TEST_CASE(simple_test_end_dependency)
{
migraphx::program p;
auto one = p.add_literal(1.0);
auto two = p.add_literal(2.0);
auto three = p.add_literal(3.0);
auto ans = p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, ans, three);
p.add_instruction(migraphx::op::identity{}, ans);
p.compile(eliminate_identity_target{});
EXPECT(!std::none_of(p.begin(), p.end(), [](const migraphx::instruction& ins){ return ins.name() == "identity"; }));
auto result = p.eval({});
EXPECT(result == migraphx::literal{3.0});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment