Commit a6d98bad authored by Paul's avatar Paul
Browse files

Fix replacement instruction

parent bf3789fa
...@@ -56,7 +56,6 @@ struct instruction ...@@ -56,7 +56,6 @@ struct instruction
{ {
std::replace(arguments.begin(), arguments.end(), old, new_ins); std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this); old->remove_output(*this);
recompute_shape();
} }
void clear_arguments() void clear_arguments()
...@@ -142,6 +141,13 @@ inline void backreference(instruction_ref ref) ...@@ -142,6 +141,13 @@ inline void backreference(instruction_ref ref)
arg->add_output(ref); arg->add_output(ref);
} }
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
// TODO: Move to a cpp file // TODO: Move to a cpp file
// TODO: Use const ref for vector // TODO: Use const ref for vector
inline shape compute_shape(operation op, std::vector<instruction_ref> args) inline shape compute_shape(operation op, std::vector<instruction_ref> args)
......
...@@ -52,10 +52,7 @@ struct program ...@@ -52,10 +52,7 @@ struct program
instruction_ref instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args); replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref start);
instruction_ref remove_instruction(instruction_ref ins); instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last); instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
...@@ -74,7 +71,7 @@ struct program ...@@ -74,7 +71,7 @@ struct program
instruction_ref add_parameter(std::string name, shape s); instruction_ref add_parameter(std::string name, shape s);
shape get_parameter_shape(std::string name); shape get_parameter_shape(std::string name) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const; std::unordered_map<std::string, shape> get_parameter_shapes() const;
...@@ -82,8 +79,8 @@ struct program ...@@ -82,8 +79,8 @@ struct program
bool has_instruction(instruction_ref ins) const; bool has_instruction(instruction_ref ins) const;
instruction_ref begin(); instruction_ref begin() const;
instruction_ref end(); instruction_ref end() const;
shape get_shape() const; shape get_shape() const;
......
...@@ -57,32 +57,35 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst ...@@ -57,32 +57,35 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
return ins; return ins;
} }
instruction_ref instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
program::replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last)
{ {
auto rep = std::prev(last); assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
// TODO: Should it be an error if the output is empty?
if(ins->output.empty())
{
remove_instruction(ins);
return rep;
}
for(auto&& out : ins->output) for(auto&& out : ins->output)
{ {
if(std::find(start, last, out) == last) // TODO: Check for possible cycles
if(out != rep)
{ {
out->replace_argument(ins, rep); replace_argument(out, ins, rep);
backreference(out);
} }
assert(out->valid(begin())); assert(out->valid(begin()));
} }
assert(rep->valid(begin())); // Replacement should not be dead code unless its the last instruction
assert(!rep->output.empty() or rep == std::prev(end()));
assert(ins->valid(begin())); assert(ins->valid(begin()));
if(ins->output.empty()) if(ins->output.empty())
remove_instruction(ins); remove_instruction(ins);
assert(rep->valid(begin()));
return rep; return rep;
} }
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref start)
{
assert(ins != start);
return replace_instructions(ins, start, std::next(start));
}
instruction_ref program::remove_instruction(instruction_ref ins) instruction_ref program::remove_instruction(instruction_ref ins)
{ {
assert(has_instruction(ins)); assert(has_instruction(ins));
...@@ -126,7 +129,7 @@ instruction_ref program::add_parameter(std::string name, shape s) ...@@ -126,7 +129,7 @@ instruction_ref program::add_parameter(std::string name, shape s)
return impl->instructions.begin(); return impl->instructions.begin();
} }
shape program::get_parameter_shape(std::string name) shape program::get_parameter_shape(std::string name) const
{ {
auto ins = std::find_if( auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) { impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
...@@ -167,8 +170,8 @@ bool program::has_instruction(instruction_ref ins) const ...@@ -167,8 +170,8 @@ bool program::has_instruction(instruction_ref ins) const
}) != impl->instructions.end(); }) != impl->instructions.end();
} }
instruction_ref program::begin() { return impl->instructions.begin(); } instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() { return impl->instructions.end(); } instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; } shape program::get_shape() const { return impl->instructions.back().result; }
......
...@@ -112,6 +112,7 @@ void replace_test() ...@@ -112,6 +112,7 @@ void replace_test()
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.replace_instruction(sum, minus_op{}, two, one); p.replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{1}); EXPECT(result == migraph::literal{1});
...@@ -127,6 +128,7 @@ void replace_ins_test() ...@@ -127,6 +128,7 @@ void replace_ins_test()
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
p.replace_instruction(sum, minus); p.replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{1}); EXPECT(result == migraph::literal{1});
...@@ -140,23 +142,10 @@ void replace_ins_test2() ...@@ -140,23 +142,10 @@ void replace_ins_test2()
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one); auto minus = p.add_instruction(minus_op{}, two, one);
p.add_instruction(pass_op{}, minus);
p.replace_instruction(two, sum); p.replace_instruction(two, sum);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{2});
EXPECT(result != migraph::literal{3});
}
void replace_inss_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one);
p.replace_instructions(two, two, std::next(sum));
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{2}); EXPECT(result == migraph::literal{2});
...@@ -174,6 +163,7 @@ void insert_replace_test() ...@@ -174,6 +163,7 @@ void insert_replace_test()
auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two); auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
p.replace_instruction(sum1, minus_op{}, sum0, two); p.replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({}); auto result = p.eval({});
EXPECT(result == migraph::literal{4}); EXPECT(result == migraph::literal{4});
...@@ -228,7 +218,6 @@ int main() ...@@ -228,7 +218,6 @@ int main()
replace_test(); replace_test();
replace_ins_test(); replace_ins_test();
replace_ins_test2(); replace_ins_test2();
replace_inss_test();
insert_replace_test(); insert_replace_test();
target_test(); target_test();
reverse_target_test(); reverse_target_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment