Commit d2d5fd19 authored by Paul's avatar Paul
Browse files

Fix test for replacements

parent 64370f87
......@@ -14,9 +14,7 @@ void auto_contiguous::apply(program& p) const
if(not s.standard())
{
auto c = p.insert_instruction(std::next(ins), contiguous{}, ins);
p.replace_instructions(ins, ins, std::next(c));
// auto prev = p.insert_instruction(ins, ins->op, ins->arguments);
// p.replace_instruction(ins, contiguous{}, prev);
p.replace_instruction(ins, c);
}
}
}
......
......@@ -55,6 +55,7 @@ struct instruction
void replace_argument(instruction_ref old, instruction_ref new_ins)
{
std::replace(arguments.begin(), arguments.end(), old, new_ins);
old->remove_output(*this);
recompute_shape();
}
......@@ -62,7 +63,7 @@ struct instruction
{
for(auto&& arg : arguments)
{
migraph::erase(arg->output, *this);
arg->remove_output(*this);
}
arguments.clear();
}
......@@ -73,6 +74,16 @@ struct instruction
}
bool valid(instruction_ref start) const
{
return valid() &&
std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->output.begin(), i->output.end(), *this);
return self != i->output.end() &&
std::distance(start, i) < std::distance(start, *self);
});
}
bool valid() const
{
shape computed;
if(op.name() == "@literal")
......@@ -100,12 +111,7 @@ struct instruction
[&](instruction_ref i) {
return std::find(i->arguments.begin(), i->arguments.end(), *this) !=
i->arguments.end();
}) &&
std::all_of(arguments.begin(), arguments.end(), [&](instruction_ref i) {
auto self = std::find(i->output.begin(), i->output.end(), *this);
return self != i->output.end() &&
std::distance(start, i) < std::distance(start, *self);
});
});
}
friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
......@@ -114,6 +120,18 @@ struct instruction
friend bool operator!=(instruction_ref ref, const instruction& i) { return !(i == ref); }
void add_output(instruction_ref ins)
{
if(std::find(output.begin(), output.end(), ins) == output.end())
output.push_back(ins);
}
template<class T>
void remove_output(const T& ins)
{
migraph::erase(output, ins);
}
operation op;
shape result;
std::vector<instruction_ref> output;
......@@ -124,7 +142,7 @@ struct instruction
inline void backreference(instruction_ref ref)
{
for(auto&& arg : ref->arguments)
arg->output.push_back(ref);
arg->add_output(ref);
}
// TODO: Move to a cpp file
......
......@@ -55,6 +55,9 @@ struct program
instruction_ref
replace_instructions(instruction_ref ins, instruction_ref start, instruction_ref last);
instruction_ref
replace_instruction(instruction_ref ins, instruction_ref start);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
......
......@@ -38,6 +38,7 @@ program::insert_instruction(instruction_ref ins, operation op, std::vector<instr
auto result = impl->instructions.insert(ins, {op, r, args});
backreference(result);
assert(result->arguments == args);
assert(result->valid(begin()));
return result;
}
......@@ -52,6 +53,7 @@ program::replace_instruction(instruction_ref ins, operation op, std::vector<inst
shape r = compute_shape(op, args);
ins->replace(op, r, args);
backreference(ins);
assert(ins->valid(begin()));
return ins;
}
......@@ -61,16 +63,25 @@ program::replace_instructions(instruction_ref ins, instruction_ref start, instru
auto rep = std::prev(last);
for(auto&& out : ins->output)
{
if(std::find(start, last, out) == last)
{
out->replace_argument(ins, rep);
backreference(out);
}
assert(out->valid(begin()));
}
assert(rep->valid(begin()));
assert(ins->valid(begin()));
if(ins->output.empty())
return remove_instruction(ins);
return ins;
remove_instruction(ins);
return rep;
}
instruction_ref
program::replace_instruction(instruction_ref ins, instruction_ref start)
{
assert(ins != start);
return replace_instructions(ins, start, std::next(start));
}
instruction_ref program::remove_instruction(instruction_ref ins)
......@@ -182,7 +193,7 @@ void program::compile(const target& t)
{
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " +
std::to_string(index));
std::to_string(index) + ": " + invalid->op.name());
}
#endif
}
......
......@@ -27,7 +27,7 @@ migraph::literal get_2() { return migraph::literal{{migraph::shape::float_type,
migraph::literal get_2_broadcasted()
{
return migraph::literal{{migraph::shape::float_type, {2}, {1, 0}}, {1, 2}};
return migraph::literal{{migraph::shape::float_type, {2, 1}, {1, 0}}, {1, 2}};
}
void literal_broadcast()
......@@ -116,7 +116,7 @@ void after_param_broadcast()
int main()
{
literal_broadcast();
// literal_broadcast();
literal_transpose();
after_literal_transpose();
after_literal_broadcast();
......
......@@ -118,6 +118,51 @@ void replace_test()
EXPECT(result != migraph::literal{3});
}
void replace_ins_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto minus = p.add_instruction(minus_op{}, two, one);
p.replace_instruction(sum, minus);
auto result = p.eval({});
EXPECT(result == migraph::literal{1});
EXPECT(result != migraph::literal{3});
}
void replace_ins_test2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one);
p.replace_instruction(two, sum);
auto result = p.eval({});
EXPECT(result == migraph::literal{2});
EXPECT(result != migraph::literal{3});
}
void replace_inss_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(minus_op{}, two, one);
p.replace_instructions(two, two, std::next(sum));
auto result = p.eval({});
EXPECT(result == migraph::literal{2});
EXPECT(result != migraph::literal{3});
}
void insert_replace_test()
{
migraph::program p;
......@@ -181,6 +226,9 @@ int main()
print_test();
param_test();
replace_test();
replace_ins_test();
replace_ins_test2();
replace_inss_test();
insert_replace_test();
target_test();
reverse_target_test();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment