Commit 6812d12c authored by Paul's avatar Paul
Browse files

Add test for passes

parent f0c512bc
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
namespace migraph {
template <class T>
struct iterator_for_range {
T* base;
using base_iterator = decltype(base->begin());
struct iterator {
base_iterator i;
base_iterator operator * () { return i; }
base_iterator operator ++ () { return ++i; }
bool operator != (const iterator& rhs) { return i != rhs.i; }
};
iterator begin() { return {base->begin()}; }
iterator end() { return {base->end()}; }
};
template <class T>
iterator_for_range<T> iterator_for(T& x)
{
return {&x};
}
} // namespace migraph
#endif
......@@ -2,6 +2,8 @@
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
......@@ -72,6 +74,49 @@ struct id_target
migraph::context get_context() const { return {}; }
};
struct reverse_pass
{
std::string name() const
{
return "reverse_pass";
}
void apply(migraph::program& p) const
{
for(auto ins:migraph::iterator_for(p))
{
if(ins->op.name() == "sum")
{
p.replace_instruction(ins, minus_op{}, ins->arguments);
}
else if(ins->op.name() == "minus")
{
p.replace_instruction(ins, sum_op{}, ins->arguments);
}
}
}
};
struct reverse_target
{
std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return { reverse_pass{} };
}
migraph::context get_context() const { return {}; }
};
struct double_reverse_target
{
std::string name() const { return "double_reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return { reverse_pass{}, reverse_pass{} };
}
migraph::context get_context() const { return {}; }
};
void literal_test1()
{
migraph::program p;
......@@ -170,6 +215,32 @@ void target_test()
EXPECT(result != migraph::literal{4});
}
void reverse_target_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
p.compile(reverse_target{});
auto result = p.eval({});
EXPECT(result == migraph::literal{1});
EXPECT(result != migraph::literal{4});
}
void double_reverse_target_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
p.compile(double_reverse_target{});
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
int main()
{
literal_test1();
......@@ -179,4 +250,5 @@ int main()
replace_test();
insert_replace_test();
target_test();
reverse_target_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment