Commit cf5baca1 authored by Paul's avatar Paul
Browse files

Add tests for validate

parent dd465fab
...@@ -19,7 +19,7 @@ struct outline ...@@ -19,7 +19,7 @@ struct outline
{ {
shape s; shape s;
std::string name() const { return "@outline"; } std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); } shape compute_shape(std::vector<shape>) const { return s; }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); } argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
}; };
......
...@@ -64,18 +64,24 @@ struct instruction ...@@ -64,18 +64,24 @@ struct instruction
bool valid(instruction_ref start) const bool valid(instruction_ref start) const
{ {
std::vector<shape> shapes(arguments.size());
std::transform(arguments.begin(), arguments.end(), shapes.begin(), [](instruction_ref ins) {
return ins->result;
});
shape computed; shape computed;
try if(op.name() == "@literal")
{ {
computed = op.compute_shape(shapes); computed = lit.get_shape();
} }
catch(migraph::exception&) else if(op.name() == "@param")
{ {
return false; computed = result;
}
else {
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
} }
return result == computed && return result == computed &&
std::all_of(output.begin(), std::all_of(output.begin(),
......
...@@ -129,28 +129,34 @@ instruction_ref program::validate() const ...@@ -129,28 +129,34 @@ instruction_ref program::validate() const
{ {
return std::find_if(impl->instructions.begin(), return std::find_if(impl->instructions.begin(),
impl->instructions.end(), impl->instructions.end(),
[&](const instruction& i) { return i.valid(impl->instructions.begin()); }); [&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
} }
void program::compile(const target& t) void program::compile(const target& t)
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
for(auto&& p : t.get_passes(this->impl->ctx)) for(auto&& p : t.get_passes(this->impl->ctx))
{ {
p.apply(*this); p.apply(*this);
#ifndef NDEBUG #ifndef NDEBUG
if(this->validate() == impl->instructions.end()) auto invalid = this->validate();
MIGRAPH_THROW(p.name() + " pass produces invalid program"); if(invalid != impl->instructions.end()) {
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " + std::to_string(index));
}
#endif #endif
} }
if(this->validate() == impl->instructions.end()) auto invalid = this->validate();
MIGRAPH_THROW("Invalid program from compilation"); if(invalid != impl->instructions.end()) {
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
} }
argument program::eval(std::unordered_map<std::string, argument> params) const argument program::eval(std::unordered_map<std::string, argument> params) const
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() == impl->instructions.end());
std::unordered_map<const instruction*, argument> results; std::unordered_map<const instruction*, argument> results;
argument result; argument result;
for(auto& ins : impl->instructions) for(auto& ins : impl->instructions)
......
#include <migraph/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
void simple_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
std::cout << std::distance(p.begin(), p.validate()) << std::endl;
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
void out_of_order()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
p.move_instruction(two, p.end());
EXPECT(bool{p.validate() == ins});
}
int main() {
simple_test();
out_of_order();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment