Unverified Commit 43230d29 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Update dead_code_elimination to remove unused modules (#820)

* Update pass manager to get modules after every pass

* Add program overload for module

* Formatting

* Hash modules for quicker lookup of modules

* Bump file version

* Add methods to remove modules

* Formatting

* Eliminate unused modules

* Formatting

* Fix test errors

* Foramtting

* Fix tidy issues
parent 6887a000
...@@ -29,14 +29,16 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last) ...@@ -29,14 +29,16 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
return -n; return -n;
} }
void dead_code_elimination::apply(module& p) const void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
void dead_code_elimination::apply(module& m) const
{ {
auto last = std::prev(p.end()); auto last = std::prev(m.end());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(m))
{ {
// Skip the first instruction, since we always process the previous // Skip the first instruction, since we always process the previous
// instruction // instruction
if(ins == p.begin()) if(ins == m.begin())
continue; continue;
const auto i = std::prev(ins); const auto i = std::prev(ins);
// Skip the last instruction // Skip the last instruction
...@@ -46,9 +48,9 @@ void dead_code_elimination::apply(module& p) const ...@@ -46,9 +48,9 @@ void dead_code_elimination::apply(module& p) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity") i->name() != "undefined" and i->name() != "identity")
continue; continue;
assert(bidistance(p, i, last) > 0); assert(bidistance(m, i, last) > 0);
fix([&](auto self, auto leaf) { fix([&](auto self, auto leaf) {
if(not p.has_instruction(leaf)) if(not m.has_instruction(leaf))
return; return;
if(leaf->outputs().empty()) if(leaf->outputs().empty())
...@@ -56,15 +58,15 @@ void dead_code_elimination::apply(module& p) const ...@@ -56,15 +58,15 @@ void dead_code_elimination::apply(module& p) const
std::unordered_set<instruction_ref> args(leaf->inputs().begin(), std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end()); leaf->inputs().end());
leaf->clear_arguments(); leaf->clear_arguments();
assert(bidistance(p, last, leaf) < 0); assert(bidistance(m, last, leaf) < 0);
assert(leaf != ins); assert(leaf != ins);
p.move_instruction(leaf, p.end()); m.move_instruction(leaf, m.end());
for(auto arg : args) for(auto arg : args)
self(arg); self(arg);
} }
})(i); })(i);
} }
p.remove_instructions(std::next(last), p.end()); m.remove_instructions(std::next(last), m.end());
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,7 @@ namespace migraphx { ...@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct module; struct module;
struct program;
/** /**
* Remove instructions where the output is not used. * Remove instructions where the output is not used.
...@@ -16,7 +17,8 @@ struct module; ...@@ -16,7 +17,8 @@ struct module;
struct dead_code_elimination struct dead_code_elimination
{ {
std::string name() const { return "dead_code_elimination"; } std::string name() const { return "dead_code_elimination"; }
void apply(module& p) const; void apply(module& m) const;
void apply(program& p) const;
}; };
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -226,6 +226,11 @@ struct id ...@@ -226,6 +226,11 @@ struct id
} }
}; };
template <class... Ts>
void nop(Ts&&...)
{
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
#include <migraphx/config.hpp>
#include <iterator>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
struct function_output_iterator
{
F f;
using self = function_output_iterator;
using difference_type = void;
using reference = void;
using value_type = void;
using pointer = void;
using iterator_category = std::output_iterator_tag;
struct output_proxy
{
template <class T>
output_proxy& operator=(const T& value)
{
assert(f);
(*f)(value);
return *this;
}
F* f;
};
output_proxy operator*() { return output_proxy{&f}; }
self& operator++() { return *this; }
self& operator++(int) { return *this; } // NOLINT
};
template <class F>
function_output_iterator<F> make_function_output_iterator(F f)
{
return {std::move(f)};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OUTPUT_ITERATOR_HPP
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -23,8 +23,10 @@ struct pass ...@@ -23,8 +23,10 @@ struct pass
{ {
/// A unique name used to identify the pass /// A unique name used to identify the pass
std::string name() const; std::string name() const;
/// Run the pass on the module
void apply(module& m) const;
/// Run the pass on the program /// Run the pass on the program
void apply(module& p) const; void apply(program& p) const;
}; };
#else #else
...@@ -35,7 +37,8 @@ struct pass ...@@ -35,7 +37,8 @@ struct pass
* struct pass * struct pass
* { * {
* std::string name() const; * std::string name() const;
* void apply(module & p) const; * void apply(module & m) const;
* void apply(program & p) const;
* }; * };
* *
*/ */
...@@ -109,7 +112,13 @@ struct pass ...@@ -109,7 +112,13 @@ struct pass
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(module& p) const void apply(module& m) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(m);
}
void apply(program& p) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(p); (*this).private_detail_te_get_handle().apply(p);
...@@ -128,10 +137,37 @@ struct pass ...@@ -128,10 +137,37 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(module& p) const = 0; virtual void apply(module& m) const = 0;
virtual void apply(program& p) const = 0;
}; };
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, module& m)
-> decltype(private_detail_te_self.apply(m))
{
private_detail_te_self.apply(m);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, module& m)
{
migraphx::nop(private_detail_te_self, m);
}
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, program& p)
-> decltype(private_detail_te_self.apply(p))
{
private_detail_te_self.apply(p);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, program& p)
{
migraphx::nop(private_detail_te_self, p);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -162,7 +198,17 @@ struct pass ...@@ -162,7 +198,17 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
void apply(module& p) const override { private_detail_te_value.apply(p); } void apply(module& m) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, m);
}
void apply(program& p) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, p);
}
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace = tracer{}); void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -101,6 +101,9 @@ struct program ...@@ -101,6 +101,9 @@ struct program
std::vector<const module*> get_modules() const; std::vector<const module*> get_modules() const;
std::vector<module*> get_modules(); std::vector<module*> get_modules();
void remove_module(const std::string& name);
void remove_unused_modules();
private: private:
void assign(const program& p); void assign(const program& p);
std::unique_ptr<program_impl> impl; std::unique_ptr<program_impl> impl;
......
...@@ -168,6 +168,12 @@ void copy(Range&& r, Iterator it) ...@@ -168,6 +168,12 @@ void copy(Range&& r, Iterator it)
std::copy(r.begin(), r.end(), it); std::copy(r.begin(), r.end(), it);
} }
template <class Range>
auto reverse(Range& r)
{
return range(std::make_reverse_iterator(r.end()), std::make_reverse_iterator(r.begin()));
}
template <class Range, class T> template <class Range, class T>
void replace(Range&& r, const T& old, const T& new_x) void replace(Range&& r, const T& old, const T& new_x)
{ {
......
...@@ -405,6 +405,8 @@ instruction_ref module::end() const { return impl->instructions.end(); } ...@@ -405,6 +405,8 @@ instruction_ref module::end() const { return impl->instructions.end(); }
std::vector<shape> module::get_output_shapes() const std::vector<shape> module::get_output_shapes() const
{ {
if(impl->instructions.empty())
return {};
auto last_ins = impl->instructions.back(); auto last_ins = impl->instructions.back();
if(last_ins.name() == "@return") if(last_ins.name() == "@return")
{ {
......
...@@ -15,25 +15,56 @@ ...@@ -15,25 +15,56 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
void run_passes(module& modl, const std::vector<pass>& passes, tracer trace) void validate_pass(module& mod, const pass& p, tracer trace)
{
(void)mod;
(void)p;
(void)trace;
#ifndef NDEBUG
trace("Validate ...");
auto invalid = mod.validate();
if(invalid != mod.end())
{
auto index = std::distance(mod.begin(), invalid);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
}
trace();
#endif
}
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
p.apply(prog);
trace(prog);
}
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{ {
for(const auto& p : passes) for(const auto& p : passes)
{ {
trace("Module: ", modl.name(), ", Pass: ", p.name()); run_pass(mod, p, trace);
p.apply(modl); }
trace(modl); }
#ifndef NDEBUG void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
trace("Validate ..."); {
auto invalid = modl.validate(); for(const auto& p : passes)
if(invalid != modl.end()) {
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{ {
auto index = std::distance(modl.begin(), invalid); run_pass(*mod, p, trace);
MIGRAPHX_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index) + ": " + invalid->name());
} }
trace(); run_pass(prog, p, trace);
#endif
} }
} }
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp> #include <migraphx/algorithm.hpp>
#include <migraphx/output_iterator.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -144,14 +145,14 @@ void program::compile(const target& t, compile_options options) ...@@ -144,14 +145,14 @@ void program::compile(const target& t, compile_options options)
options.trace(*this); options.trace(*this);
options.trace(); options.trace();
auto mods = this->get_modules();
std::reverse(mods.begin(), mods.end());
auto&& passes = t.get_passes(this->impl->ctx, options); auto&& passes = t.get_passes(this->impl->ctx, options);
run_passes(*this, passes, options.trace);
auto mods = this->get_modules();
for(const auto& mod : mods) // Validate and finalize
for(const auto& mod : reverse(mods))
{ {
assert(mod->validate() == mod->end());
run_passes(*mod, passes, options.trace);
auto invalid = mod->validate(); auto invalid = mod->validate();
if(invalid != mod->end()) if(invalid != mod->end())
{ {
...@@ -306,7 +307,7 @@ std::vector<argument> program::eval(parameter_map params) const ...@@ -306,7 +307,7 @@ std::vector<argument> program::eval(parameter_map params) const
} }
} }
const int program_file_version = 4; const int program_file_version = 5;
value program::to_value() const value program::to_value() const
{ {
...@@ -656,6 +657,7 @@ const module* program::get_module(const std::string& name) const { return &impl- ...@@ -656,6 +657,7 @@ const module* program::get_module(const std::string& name) const { return &impl-
module* program::create_module(const std::string& name) module* program::create_module(const std::string& name)
{ {
assert(not contains(impl->modules, name));
auto r = impl->modules.emplace(name, name); auto r = impl->modules.emplace(name, name);
return &(r.first->second); return &(r.first->second);
} }
...@@ -704,6 +706,53 @@ std::vector<module*> program::get_modules() ...@@ -704,6 +706,53 @@ std::vector<module*> program::get_modules()
return result; return result;
} }
template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{
bool is_unused = false;
generic_get_unused_modules(m, mods, make_function_output_iterator([&](auto* mod) {
if(mod->name() == name)
is_unused = true;
}));
return is_unused;
}
template <class Map>
bool references_instruction(Map& m, const instruction& ins, const std::string& name)
{
return std::any_of(m.begin(), m.end(), [&](auto&& p) {
if(p.first == name)
return false;
return std::any_of(p.second.begin(), p.second.end(), [&](auto&& i) {
return std::any_of(i.inputs().begin(), i.inputs().end(), [&](auto&& j) {
return std::addressof(*j) == std::addressof(ins);
});
});
});
}
void program::remove_module(const std::string& name)
{
// cppcheck-suppress assertWithSideEffect
assert(is_unused_module(impl->modules, generic_get_modules(this->get_main_module()), name) &&
"Module used in program");
assert(std::none_of(
impl->modules.at(name).begin(),
impl->modules.at(name).end(),
[&](auto&& ins) { return references_instruction(impl->modules, ins, name); }) &&
"Instruction referenced in another module");
impl->modules.erase(name);
}
void program::remove_unused_modules()
{
std::vector<module*> unused;
generic_get_unused_modules(
impl->modules, generic_get_modules(this->get_main_module()), std::back_inserter(unused));
for(auto* m : unused)
this->remove_module(m->name());
}
program& program::sort() program& program::sort()
{ {
for(auto& pp : this->impl->modules) for(auto& pp : this->impl->modules)
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
#include <migraphx/pass_manager.hpp> #include <migraphx/pass_manager.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <test.hpp> #include <test.hpp>
void run_pass(migraphx::program& p) void run_pass(migraphx::program& p)
{ {
migraphx::run_passes(*p.get_main_module(), {migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::dead_code_elimination{}});
} }
TEST_CASE(simple_test) TEST_CASE(simple_test)
...@@ -177,4 +178,21 @@ TEST_CASE(duplicate_args3) ...@@ -177,4 +178,21 @@ TEST_CASE(duplicate_args3)
EXPECT(result == migraphx::literal{0}); EXPECT(result == migraphx::literal{0});
} }
TEST_CASE(unused_module)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* m1 = p.create_module("unused");
auto* m2 = p.create_module("used");
auto l0 = mm->add_literal(0);
m1->add_literal(0);
m2->add_literal(0);
mm->add_instruction(mod_pass_op{}, {l0}, {m2});
EXPECT(migraphx::contains(p.get_modules(), m1));
EXPECT(migraphx::contains(p.get_modules(), m2));
run_pass(p);
EXPECT(migraphx::contains(p.get_modules(), m2));
EXPECT(not migraphx::contains(p.get_modules(), m1));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -3,10 +3,10 @@ ...@@ -3,10 +3,10 @@
#include <cassert> #include <cassert>
#include <string> #include <string>
#include <functional>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -23,8 +23,10 @@ struct pass ...@@ -23,8 +23,10 @@ struct pass
{ {
/// A unique name used to identify the pass /// A unique name used to identify the pass
std::string name() const; std::string name() const;
/// Run the pass on the module
void apply(module& m) const;
/// Run the pass on the program /// Run the pass on the program
void apply(module& p) const; void apply(program& p) const;
}; };
#else #else
...@@ -32,7 +34,8 @@ struct pass ...@@ -32,7 +34,8 @@ struct pass
<% <%
interface('pass', interface('pass',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='module &', const=True) virtual('apply', returns='void', m='module &', const=True, default='migraphx::nop'),
virtual('apply', returns='void', p='program &', const=True, default='migraphx::nop')
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment