"git@developer.sourcefind.cn:gaoqiong/yaml-cpp.git" did not exist on "b35a332fdd5a1abadad01d41bf734e4d9552ed76"
Unverified Commit da26db34 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add flag to bypass passes on modules (#949)

Needed to bypass passes when fusing pointwise operators into a module.
parent 985f58b0
...@@ -46,6 +46,9 @@ struct module ...@@ -46,6 +46,9 @@ struct module
std::string name() const; std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args) instruction_ref add_instruction(operation op, Ts... args)
{ {
......
...@@ -28,6 +28,7 @@ struct module_impl ...@@ -28,6 +28,7 @@ struct module_impl
std::unordered_set<instruction*> instruction_set; std::unordered_set<instruction*> instruction_set;
std::string name; std::string name;
uint32_t nparams = 0; uint32_t nparams = 0;
bool bypass = false;
bool contains(instruction_ref ins) const bool contains(instruction_ref ins) const
{ {
...@@ -49,6 +50,13 @@ struct module_impl ...@@ -49,6 +50,13 @@ struct module_impl
return emplace(pos, ins); return emplace(pos, ins);
} }
void clear()
{
instructions.clear();
instruction_set.clear();
nparams = 0;
}
void push_front(const instruction& ins) { insert(instructions.begin(), ins); } void push_front(const instruction& ins) { insert(instructions.begin(), ins); }
void push_back(const instruction& ins) { insert(instructions.end(), ins); } void push_back(const instruction& ins) { insert(instructions.end(), ins); }
...@@ -100,18 +108,21 @@ module& module::operator=(module m) ...@@ -100,18 +108,21 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; } std::string module::name() const { return impl->name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
void module::assign(const module& m) void module::assign(const module& m)
{ {
// clean the current module // copy the impl
if(!impl) if(!impl)
{
impl = std::make_unique<module_impl>(); impl = std::make_unique<module_impl>();
} *impl = *m.impl;
else if(!impl->instructions.empty())
// clear instructions
if(!impl->instructions.empty())
{ {
impl->instructions.clear(); impl->clear();
} }
impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
......
...@@ -95,6 +95,8 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -95,6 +95,8 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules(); auto mods = prog.get_modules();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
if(mod->bypass())
continue;
module_pm{mod, &prog, &trace}.run_pass(p); module_pm{mod, &prog, &trace}.run_pass(p);
} }
run_pass(prog, p, trace); run_pass(prog, p, trace);
......
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/ref/target.hpp> #include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <sstream> #include <sstream>
...@@ -276,4 +277,39 @@ TEST_CASE(parameter_name_order) ...@@ -276,4 +277,39 @@ TEST_CASE(parameter_name_order)
EXPECT(param_names == names1); EXPECT(param_names == names1);
} }
struct check_for_pass_op
{
bool* found = nullptr;
std::string name() const { return "check_for_pass_op"; }
void apply(migraphx::module& m) const
{
*found |= std::any_of(m.begin(), m.end(), [](auto&& ins) { return ins.name() == "pass"; });
}
};
TEST_CASE(module_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->set_bypass();
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(not found);
}
TEST_CASE(module_without_bypass)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto* sub = p.create_module("sub");
sub->add_instruction(pass_op{});
mm->add_instruction(mod_pass_op{}, {}, {sub});
bool found = false;
migraphx::run_passes(p, {check_for_pass_op{&found}});
EXPECT(found);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment