"magic_pdf/vscode:/vscode.git/clone" did not exist on "3fb325dd2278d38576a30eb4c9470c0727a5da9c"
Unverified Commit e758d457 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fuse last instruction in fuse_pointwise (#1015)

Fuse last instruction in fuse_pointwise
This is also fixes a bug with using an invalid iterator.
parent 00bfed4d
......@@ -126,22 +126,25 @@ static std::vector<instruction_ref> append_pointwise_module(instruction_ref ins,
static bool find_pointwise_modules(module& m)
{
bool changed = false;
auto last = std::prev(m.end());
for(auto ins : iterator_for(m))
{
if(ins->name() != "pointwise")
continue;
if(ins->outputs().empty())
if(ins->outputs().empty() and ins != last)
continue;
auto it = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->name() == "pointwise" and i->outputs().size() == 1;
});
if(it == ins->inputs().end())
continue;
auto input = *it;
auto new_inputs = append_pointwise_module(input, ins);
m.replace_instruction(input, input->get_operator(), new_inputs, input->module_inputs());
m.replace_instruction(ins, input);
m.move_instruction(input, ins);
auto new_inputs = append_pointwise_module(*it, ins);
m.replace_instruction(*it, (*it)->get_operator(), new_inputs, (*it)->module_inputs());
m.replace_instruction(ins, *it);
m.move_instruction(*it, ins);
changed = true;
}
return changed;
......
......@@ -179,6 +179,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
auto result = impl->insert(ins, {op, r, std::move(args)});
......@@ -200,6 +201,7 @@ instruction_ref module::insert_instruction(instruction_ref ins,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args)
{
assert(has_instruction(ins) or is_end(ins, this->end()));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
auto result = impl->insert(ins, {op, out_shape, std::move(args), std::move(module_args)});
......@@ -212,6 +214,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
shape r = compute_shape(op, args);
......@@ -225,6 +228,7 @@ instruction_ref module::replace_instruction(instruction_ref ins,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST
{
assert(has_instruction(ins));
assert(not starts_with(op.name(), "@"));
auto out_shape = compute_shape(op, args, module_args);
instruction::replace(ins, op, out_shape, std::move(args), std::move(module_args));
......@@ -291,6 +295,8 @@ instruction_ref module::remove_instructions(instruction_ref first, instruction_r
instruction_ref module::move_instruction(instruction_ref src, instruction_ref dst)
{
assert(has_instruction(src));
assert(has_instruction(dst) or is_end(dst, this->end()));
impl->instructions.splice(dst, impl->instructions, src);
return src;
}
......
......@@ -73,6 +73,35 @@ TEST_CASE(double_add)
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(double_add_without_return)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto add1 = mm->add_instruction(migraphx::make_op("add"), x, y);
mm->add_instruction(migraphx::make_op("add"), add1, z);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto z = mm->add_parameter("z", s);
auto fadd =
add_pointwise(p2, "main:pointwise0", {x, y, z}, [=](auto* pm, const auto& inputs) {
auto add1 = pm->add_instruction(migraphx::make_op("add"), inputs[0], inputs[1]);
return pm->add_instruction(migraphx::make_op("add"), add1, inputs[2]);
});
mm->add_instruction(migraphx::make_op("identity"), fadd);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(used_twice_not_fused)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment