Commit 299b33d1 authored by Paul's avatar Paul
Browse files

Add initial cse

parent 945e89e0
add_library(migraph add_library(migraph
auto_contiguous.cpp auto_contiguous.cpp
common_subexpression_elimination.cpp
constant_propagate.cpp constant_propagate.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
......
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/ranges.hpp>
#include <migraph/functional.hpp>
namespace migraph {
template<class Range>
void cse_range(program& p, Range&& r)
{
std::unordered_multimap<std::string, instruction_ref> instructions;
for(auto ins : r)
{
// Skip dead instructions
if(ins->outputs().empty())
continue;
// Find instruction with the same name
auto found_instructions = range(instructions.equal_range(ins->name()));
for(auto pp:found_instructions)
{
auto eq = pp.second;
if(*eq != *ins)
continue;
p.replace_instruction(ins, eq);
cse_range(p, eq->outputs());
}
instructions.emplace(ins->name(), ins);
}
}
void common_subexpression_elimination::apply(program& p) const
{
cse_range(p, iterator_for(p));
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct common_subexpression_elimination
{
std::string name() const { return "common_subexpression_elimination"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
...@@ -43,6 +43,10 @@ struct instruction ...@@ -43,6 +43,10 @@ struct instruction
const std::vector<instruction_ref>& outputs() const; const std::vector<instruction_ref>& outputs() const;
friend bool operator==(const instruction& x, const instruction& y);
friend bool operator!=(const instruction& x, const instruction& y);
friend bool operator==(instruction_ref ref, const instruction& i); friend bool operator==(instruction_ref ref, const instruction& i);
friend bool operator!=(const instruction& i, instruction_ref ref); friend bool operator!=(const instruction& i, instruction_ref ref);
......
...@@ -92,6 +92,12 @@ iterator_range<Iterator> range(Iterator start, Iterator last) ...@@ -92,6 +92,12 @@ iterator_range<Iterator> range(Iterator start, Iterator last)
return {start, last}; return {start, last};
} }
template <class Iterator>
iterator_range<Iterator> range(std::pair<Iterator, Iterator> p)
{
return {p.first, p.second};
}
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -94,6 +94,17 @@ const std::vector<instruction_ref>& instruction::inputs() const { return argumen ...@@ -94,6 +94,17 @@ const std::vector<instruction_ref>& instruction::inputs() const { return argumen
const std::vector<instruction_ref>& instruction::outputs() const { return output; } const std::vector<instruction_ref>& instruction::outputs() const { return output; }
bool operator==(const instruction& x, const instruction& y)
{
if(not (x.result == y.result and x.op == y.op and x.arguments == y.arguments))
return false;
if(x.name() == "@literal")
return x.lit == y.lit;
return true;
}
bool operator!=(const instruction& x, const instruction& y) { return !(x == y); }
bool operator==(instruction_ref ref, const instruction& i) { return i == ref; } bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); } bool operator!=(const instruction& i, instruction_ref ref) { return !(i == ref); }
......
#include <migraph/common_subexpression_elimination.hpp>
#include <migraph/dead_code_elimination.hpp>
#include <migraph/operators.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct cse_target
{
std::string name() const { return "dce"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {migraph::common_subexpression_elimination{}, migraph::dead_code_elimination{}};
}
migraph::context get_context() const { return {}; }
};
void cse_test1()
{
migraph::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(cse_target{});
migraph::program p2;
{
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void cse_test2()
{
migraph::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(cse_target{});
migraph::program p2;
{
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraph::op::add{}, two, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void cse_test3()
{
migraph::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
p1.compile(cse_target{});
migraph::program p2;
{
auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, sum1);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
void cse_test4()
{
migraph::program p1;
{
auto one = p1.add_literal(1);
auto two = p1.add_literal(1);
auto sum1 = p1.add_instruction(migraph::op::add{}, one, two);
auto sum2 = p1.add_instruction(migraph::op::add{}, two, one);
auto sum3 = p1.add_instruction(migraph::op::add{}, sum1, one);
auto sum4 = p1.add_instruction(migraph::op::add{}, sum2, two);
auto sum5 = p1.add_instruction(migraph::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
}
p1.compile(cse_target{});
migraph::program p2;
{
auto one = p2.add_literal(1);
auto sum1 = p2.add_instruction(migraph::op::add{}, one, one);
auto sum3 = p2.add_instruction(migraph::op::add{}, sum1, one);
auto sum5 = p2.add_instruction(migraph::op::add{}, sum3, sum3);
p2.add_instruction(pass_op{}, sum5);
}
EXPECT(p1 == p2);
}
int main()
{
cse_test1();
cse_test2();
cse_test3();
cse_test4();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment