Commit cf5baca1 authored by Paul's avatar Paul
Browse files

Add tests for validate

parent dd465fab
......@@ -19,7 +19,7 @@ struct outline
{
shape s;
std::string name() const { return "@outline"; }
shape compute_shape(std::vector<shape>) const { MIGRAPH_THROW("builtin"); }
shape compute_shape(std::vector<shape>) const { return s; }
argument compute(context&, shape, std::vector<argument>) const { MIGRAPH_THROW("builtin"); }
};
......
......@@ -64,18 +64,24 @@ struct instruction
bool valid(instruction_ref start) const
{
std::vector<shape> shapes(arguments.size());
std::transform(arguments.begin(), arguments.end(), shapes.begin(), [](instruction_ref ins) {
return ins->result;
});
shape computed;
try
if(op.name() == "@literal")
{
computed = op.compute_shape(shapes);
computed = lit.get_shape();
}
catch(migraph::exception&)
else if(op.name() == "@param")
{
return false;
computed = result;
}
else {
try
{
computed = compute_shape(op, arguments);
}
catch(migraph::exception&)
{
return false;
}
}
return result == computed &&
std::all_of(output.begin(),
......
......@@ -129,28 +129,34 @@ instruction_ref program::validate() const
{
return std::find_if(impl->instructions.begin(),
impl->instructions.end(),
[&](const instruction& i) { return i.valid(impl->instructions.begin()); });
[&](const instruction& i) { return !i.valid(impl->instructions.begin()); });
}
void program::compile(const target& t)
{
assert(this->validate() != impl->instructions.end());
assert(this->validate() == impl->instructions.end());
this->impl->ctx = t.get_context();
for(auto&& p : t.get_passes(this->impl->ctx))
{
p.apply(*this);
#ifndef NDEBUG
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW(p.name() + " pass produces invalid program");
auto invalid = this->validate();
if(invalid != impl->instructions.end()) {
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW(p.name() + " pass produces invalid program at instruction " + std::to_string(index));
}
#endif
}
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW("Invalid program from compilation");
auto invalid = this->validate();
if(invalid != impl->instructions.end()) {
auto index = std::distance(impl->instructions.begin(), invalid);
MIGRAPH_THROW("Invalid program from compilation at instruction " + std::to_string(index));
}
}
argument program::eval(std::unordered_map<std::string, argument> params) const
{
assert(this->validate() != impl->instructions.end());
assert(this->validate() == impl->instructions.end());
std::unordered_map<const instruction*, argument> results;
argument result;
for(auto& ins : impl->instructions)
......
#include <migraph/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
void simple_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
std::cout << std::distance(p.begin(), p.validate()) << std::endl;
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
void out_of_order()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
p.move_instruction(two, p.end());
EXPECT(bool{p.validate() == ins});
}
int main() {
simple_test();
out_of_order();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment