Commit a6d98bad authored by Paul's avatar Paul
Browse files

Fix replacement instruction

parent bf3789fa
......@@ -56,7 +56,6 @@ struct instruction
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
recompute_shape();
}
void clear_arguments()
......@@ -142,6 +141,13 @@ inline void backreference(instruction_ref ref)
arg->add_output(ref);
}
inline void replace_argument(instruction_ref ins, instruction_ref old, instruction_ref new_ins)
{
ins->replace_argument(old, new_ins);
backreference(ins);
ins->recompute_shape();
}
// TODO: Move to a cpp file
// TODO: Use const ref for vector
inline shape compute_shape(operation op, std::vector<instruction_ref> args)
......
......@@ -52,10 +52,7 @@ struct program
instruction_ref
replace_instruction(instruction_ref ins, operation op, std::vector<instruction_ref> args);
instruction_ref
replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref start);
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......@@ -74,7 +71,7 @@ struct program
instruction_ref add_parameter(std::string name, shape s);
shape get_parameter_shape(std::string name);
shape get_parameter_shape(std::string name) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const;
......@@ -82,8 +79,8 @@ struct program
bool has_instruction(instruction_ref ins) const;
instruction_ref begin();
instruction_ref end();
instruction_ref begin() const;
instruction_ref end() const;
shape get_shape() const;
......
......@@ -57,32 +57,35 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
return ins;
}
instruction_ref
program::replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last)
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref rep)
{
auto rep = std::prev(last);
assert(has_instruction(ins));
assert(has_instruction(rep));
assert(ins != rep);
// TODO: Should it be an error if the output is empty?
if(ins->output.empty())
{
remove_instruction(ins);
return rep;
}
for(auto&& out : ins->output)
{
if(std::find(start, last, out) == last)
// TODO: Check for possible cycles
if(out != rep)
{
out->replace_argument(ins, rep);
backreference(out);
replace_argument(out, ins, rep);
}
assert(out->valid(begin()));
}
assert(rep->valid(begin()));
// Replacement should not be dead code unless its the last instruction
assert(!rep->output.empty() or rep == std::prev(end()));
assert(ins->valid(begin()));
if(ins->output.empty())
remove_instruction(ins);
assert(rep->valid(begin()));
return rep;
}
instruction_ref program::replace_instruction(instruction_ref ins, instruction_ref start)
{
assert(ins != start);
return replace_instructions(ins, start, std::next(start));
}
instruction_ref program::remove_instruction(instruction_ref ins)
{
assert(has_instruction(ins));
......@@ -126,7 +129,7 @@ instruction_ref program::add_parameter(std::string name, shape s)
return impl->instructions.begin();
}
shape program::get_parameter_shape(std::string name)
shape program::get_parameter_shape(std::string name) const
{
auto ins = std::find_if(
impl->instructions.begin(), impl->instructions.end(), [&](const instruction& x) {
......@@ -167,8 +170,8 @@ bool program::has_instruction(instruction_ref ins) const
}) != impl->instructions.end();
}
instruction_ref program::begin() { return impl->instructions.begin(); }
instruction_ref program::end() { return impl->instructions.end(); }
instruction_ref program::begin() const { return impl->instructions.begin(); }
instruction_ref program::end() const { return impl->instructions.end(); }
shape program::get_shape() const { return impl->instructions.back().result; }
......
......@@ -112,6 +112,7 @@ void replace_test()
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.replace_instruction(sum, minus_op{}, two, one);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{1});
......@@ -127,6 +128,7 @@ void replace_ins_test()
auto sum = p.add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one);
p.replace_instruction(sum, minus);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{1});
......@@ -140,23 +142,10 @@ void replace_ins_test2()
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one);
auto minus = p.add_instruction(minus_op{}, two, one);
p.add_instruction(pass_op{}, minus);
p.replace_instruction(two, sum);
auto result = p.eval({});
EXPECT(result == migraph::literal{2});
EXPECT(result != migraph::literal{3});
}
void replace_inss_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one);
p.replace_instructions(two, two, std::next(sum));
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{2});
......@@ -174,6 +163,7 @@ void insert_replace_test()
auto sum0 = p.insert_instruction(sum1, sum_op{}, two, two);
p.replace_instruction(sum1, minus_op{}, sum0, two);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{4});
......@@ -228,7 +218,6 @@ int main()
replace_test();
replace_ins_test();
replace_ins_test2();
replace_inss_test();
insert_replace_test();
target_test();
reverse_target_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment