Unverified Commit da26db34 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add flag to bypass passes on modules (#949)

Needed to bypass passes when fusing pointwise operators into a module.
parent 985f58b0
......@@ -46,6 +46,9 @@ struct module
std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args)
{
......
......@@ -28,6 +28,7 @@ struct module_impl
std::unordered_set<instruction*> instruction_set;
std::string name;
uint32_t nparams = 0;
bool bypass = false;
bool contains(instruction_ref ins) const
{
......@@ -49,6 +50,13 @@ struct module_impl
return emplace(pos, ins);
}
void clear()
{
instructions.clear();
instruction_set.clear();
nparams = 0;
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); }
......@@ -100,18 +108,21 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m)
{
// clean the current module
// copy the impl
if(!impl)
{
impl = std::make_unique<module_impl>();
}
else if(!impl->instructions.empty())
*impl = *m.impl;
// clear instructions
if(!impl->instructions.empty())
{
impl->instructions.clear();
impl->clear();
}
impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m))
......
......@@ -95,6 +95,8 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{
if(mod->bypass())
continue;
module_pm{mod, &prog, &trace}.run_pass(p);
}
run_pass(prog, p, trace);
......
#include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp>
#include <sstream>
......@@ -276,4 +277,39 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1);
}
struct check_for_pass_op
{
bool* found = nullptr;
std::string name() const { return "check_for_pass_op"; }
void apply(migraphx::module& m) const
{
*found |= std::any_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "pass"; });
}
};
TEST_CASE(module_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->set_bypass();
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(not found);
}
TEST_CASE(module_without_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(found);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment