Unverified Commit e758d457 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fuse last instruction in fuse_pointwise (#1015)

Fuse last instruction in fuse_pointwise
This is also fixes a bug with using an invalid iterator.
parent 00bfed4d
...@@ -126,22 +126,25 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins, ...@@ -126,22 +126,25 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
static bool find_pointwise_modules(module& m) static bool find_pointwise_modules(module& m)
{ {
bool changed = false; bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m)) for(auto ins : iterator_for(m))
{ {
if(ins->name() != "pointwise") if(ins->name() != "pointwise")
continue; continue;
if(ins->outputs().empty()) if(ins->outputs().empty() and ins != last)
continue; continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1; return i->name() == "pointwise" and i->outputs().size() == 1;
}); });
if(it == ins->inputs().end()) if(it == ins->inputs().end())
continue; continue;
auto input = *it;
auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true; changed = true;
} }
return changed; return changed;
......
...@@ -179,6 +179,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -179,6 +179,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->insert(ins, {op, r, std::move(args)}); auto result = impl->insert(ins, {op, r, std::move(args)});
...@@ -200,6 +201,7 @@ instruction_ref module::insert_instruction(instruction_ref ins, ...@@ -200,6 +201,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
std::vector<instruction_ref> args, std::vector<instruction_ref> args,
std::vector<module_ref> module_args) std::vector<module_ref> module_args)
{ {
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args); auto out_shape = compute_shape(op, args, module_args);
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)}); auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
...@@ -212,6 +214,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -212,6 +214,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op, const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{ {
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
...@@ -225,6 +228,7 @@ instruction_ref module::replace_instruction(instruction_ref ins, ...@@ -225,6 +228,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
std::vector<instruction_ref> args, std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{ {
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args); auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args)); instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
...@@ -291,6 +295,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r ...@@ -291,6 +295,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst) instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{ {
assert(has_instruction(src));
assert(has_instruction(dst) or is_end(dst, this->end()));
impl->instructions.splice(dst, impl->instructions, src); impl->instructions.splice(dst, impl->instructions, src);
return src; return src;
} }
......
...@@ -73,6 +73,35 @@ TEST_CASE(double_add) ...@@ -73,6 +73,35 @@ TEST_CASE(double_add)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(double_add_without_return)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), add1, z);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_instruction(migraphx::make_op("identity"), fadd);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(used_twice_not_fused) TEST_CASE(used_twice_not_fused)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment