"...resnet50_tensorflow.git" did not exist on "6ee54a60f61b0a639dfa855009c6abc3d51f4d92"
Commit 375e4c15 authored by Paul's avatar Paul
Browse files

Add dead-code elimination pass

parent 603adbe6
add_library(migraph add_library(migraph
dead_code_elimination.cpp
generate.cpp generate.cpp
program.cpp program.cpp
shape.cpp shape.cpp
......
#include <migraph/dead_code_elimination.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/functional.hpp>
namespace migraph {
void dead_code_elimination::apply(program& p) const
{
for(auto i:iterator_for(p))
{
// Skip over instructions that may have been removed
if(!p.has_instruction(i))
continue;
// Skip the last instruction
if(i == std::prev(p.end()))
break;
fix([&](auto self, auto ins) {
assert(p.has_instruction(ins));
if(ins->output.empty())
{
std::cout << p << std::endl;
auto args = ins->arguments;
p.remove_instruction(ins);
for(auto arg:args)
self(arg);
}
})(i);
}
}
} // namespace migraph
#ifndef MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct dead_code_elimination
{
std::string name() const { return "dead_code_elimination"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
namespace migraph {
namespace detail {
template<class R, class F>
struct fix_f
{
F f;
template<class... Ts>
R operator()(Ts&&... xs) const
{
return f(*this, std::forward<Ts>(xs)...);
}
};
} // namespace detail
/// Implements a fix-point combinator
template<class R, class F>
detail::fix_f<R, F> fix(F f)
{
return {f};
}
template<class F>
auto fix(F f)
{
return fix<void>(f);
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP #ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP #define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#include <cassert>
#include <type_traits>
namespace migraph { namespace migraph {
template <class T> template <class T>
struct iterator_for_range struct iterator_for_range
{ {
T* base; T* base;
using base_iterator = decltype(base->begin()); using base_iterator = std::remove_reference_t<decltype(base->begin())>;
struct iterator struct iterator
{ {
...@@ -17,8 +20,8 @@ struct iterator_for_range ...@@ -17,8 +20,8 @@ struct iterator_for_range
bool operator!=(const iterator& rhs) { return i != rhs.i; } bool operator!=(const iterator& rhs) { return i != rhs.i; }
}; };
iterator begin() { return {base->begin()}; } iterator begin() { assert(base != nullptr); return {base->begin()}; }
iterator end() { return {base->end()}; } iterator end() { assert(base != nullptr); return {base->end()}; }
}; };
template <class T> template <class T>
iterator_for_range<T> iterator_for(T& x) iterator_for_range<T> iterator_for(T& x)
......
...@@ -52,6 +52,8 @@ struct program ...@@ -52,6 +52,8 @@ struct program
instruction_ref instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args); replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref remove_instruction(instruction_ref ins);
template <class... Ts> template <class... Ts>
instruction_ref add_literal(Ts&&... xs) instruction_ref add_literal(Ts&&... xs)
{ {
......
...@@ -52,6 +52,15 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst ...@@ -52,6 +52,15 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
return ins; return ins;
} }
instruction_ref
program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
assert(ins->output.empty());
ins->clear_arguments();
return impl->instructions.erase(ins);
}
instruction_ref program::add_literal(literal l) instruction_ref program::add_literal(literal l)
{ {
impl->instructions.emplace_front(std::move(l)); impl->instructions.emplace_front(std::move(l));
......
#include <migraph/dead_code_elimination.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
struct dce_target
{
std::string name() const { return "dce"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return { migraph::dead_code_elimination{} }; }
migraph::context get_context() const { return {}; }
};
void simple_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == count);
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
void duplicate_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
p.add_instruction(sum_op{}, one, two);
auto count = std::distance(p.begin(), p.end());
p.compile(dce_target{});
EXPECT(std::distance(p.begin(), p.end()) == (count-1));
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
int main()
{
simple_test();
duplicate_test();
}
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp> #include <migraph/instruction.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp>
struct sum_op
{
std::string name() const { return "sum"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
migraph::argument result;
if(args.size() != 2)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
});
return result;
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs");
return inputs.front();
}
};
struct minus_op
{
std::string name() const { return "minus"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
migraph::argument result;
if(args.size() != 2)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
});
return result;
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs");
return inputs.front();
}
};
struct id_target struct id_target
{ {
......
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
struct sum_op
{
std::string name() const { return "sum"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
migraph::argument result;
if(args.size() != 2)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x + y}.get_argument(); });
});
return result;
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs");
return inputs.front();
}
};
struct minus_op
{
std::string name() const { return "minus"; }
migraph::argument
compute(migraph::context&, migraph::shape, std::vector<migraph::argument> args) const
{
migraph::argument result;
if(args.size() != 2)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape() != args[1].get_shape())
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().size() != 1)
MIGRAPH_THROW("Wrong args");
if(args[0].get_shape().lens().front() != 1)
MIGRAPH_THROW("Wrong args");
args[0].visit_at([&](auto x) {
args[1].visit_at([&](auto y) { result = migraph::literal{x - y}.get_argument(); });
});
return result;
}
migraph::shape compute_shape(std::vector<migraph::shape> inputs) const
{
if(inputs.size() != 2)
MIGRAPH_THROW("Wrong inputs");
return inputs.front();
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment