Unverified Commit c2842c1e authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Fix invalid program in debug mode from find_splits (#1390)

* Fix invalid program from find_splits
parent 70e63960
...@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds ...@@ -385,9 +385,13 @@ instruction_ref module::move_instruction(instruction_ref src, instruction_ref ds
instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst) instruction_ref module::move_instructions(instruction_ref src, instruction_ref dst)
{ {
this->move_instruction(src, dst);
for(auto ins : src->inputs()) for(auto ins : src->inputs())
this->move_instruction(ins, src); {
if(not contains(this->impl->instructions, ins))
continue;
this->move_instructions(ins, dst);
}
this->move_instruction(src, dst);
return src; return src;
} }
......
...@@ -435,6 +435,24 @@ struct find_concat_op ...@@ -435,6 +435,24 @@ struct find_concat_op
} }
}; };
void move_instructions_back(module& m, instruction_ref pos, std::vector<instruction_ref> inss)
{
auto start = range(m.begin(), pos);
for(auto ins : iterator_for(start))
{
auto it = std::find(inss.begin(), inss.end(), ins);
if(it != inss.end())
inss.erase(it);
}
for(auto ins : inss)
{
if(not m.has_instruction(ins))
continue;
move_instructions_back(m, pos, ins->inputs());
m.move_instruction(ins, pos);
}
}
std::vector<instruction_ref> get_splits(instruction_ref ins) std::vector<instruction_ref> get_splits(instruction_ref ins)
{ {
std::vector<instruction_ref> result; std::vector<instruction_ref> result;
...@@ -610,8 +628,7 @@ struct find_splits ...@@ -610,8 +628,7 @@ struct find_splits
})) }))
return; return;
for(auto data : data_args) move_instructions_back(m, ins, data_args);
m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator()); auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty()); assert(not slice_op.axes.empty());
...@@ -864,8 +881,7 @@ struct find_conv_dot_horiz_fusion ...@@ -864,8 +881,7 @@ struct find_conv_dot_horiz_fusion
concat_axis = axis; concat_axis = axis;
} }
for(auto arg : args) move_instructions_back(m, input, args);
m.move_instructions(arg, input);
// TODO: Check if axes match // TODO: Check if axes match
auto concat = auto concat =
m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args); m.insert_instruction(input, make_op("concat", {{"axis", concat_axis}}), args);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment