Commit 7c9c11d7 authored by Paul's avatar Paul
Browse files

Add more tests for program

parent 71e1382d
...@@ -43,6 +43,19 @@ inline std::string join_strings(Strings strings, const std::string& delim) ...@@ -43,6 +43,19 @@ inline std::string join_strings(Strings strings, const std::string& delim)
}); });
} }
template<class F>
std::string trim(const std::string &s, F f)
{
auto start = std::find_if_not(s.begin(), s.end(), f);
auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base();
return std::string(start, last);
}
inline std::string trim(const std::string &s)
{
return trim(s, [](int c){ return std::isspace(c); });
}
template <class F> template <class F>
inline std::string transform_string(std::string s, F f) inline std::string transform_string(std::string s, F f)
{ {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/stringutils.hpp>
#include <sstream> #include <sstream>
#include "test.hpp" #include "test.hpp"
#include <basic_ops.hpp> #include <basic_ops.hpp>
...@@ -181,6 +182,24 @@ TEST_CASE(param_error_test) ...@@ -181,6 +182,24 @@ TEST_CASE(param_error_test)
"Parameter not found: y")); "Parameter not found: y"));
} }
TEST_CASE(param_error_shape_test)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 1}});
auto y = p.add_parameter("y", {migraphx::shape::int32_type, {1, 1}});
p.add_instruction(sum_op{}, x, y);
EXPECT(test::throws<migraphx::exception>(
[&] {
p.eval({
{"x", migraphx::literal{1}.get_argument()},
{"y", migraphx::literal{{migraphx::shape::int32_type, {1, 1}}, {2}}.get_argument()},
});
},
"Incorrect shape {int32_type, {1}, {0}} for parameter: x"));
}
TEST_CASE(get_param1) TEST_CASE(get_param1)
{ {
migraphx::program p; migraphx::program p;
...@@ -285,7 +304,7 @@ TEST_CASE(replace_op_recompute_shape_throw) ...@@ -285,7 +304,7 @@ TEST_CASE(replace_op_recompute_shape_throw)
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
EXPECT(test::throws([&] { sum->replace(unary_pass_op{}); })); EXPECT(test::throws<migraphx::exception>([&] { sum->replace(unary_pass_op{}); }));
} }
TEST_CASE(insert_replace_test) TEST_CASE(insert_replace_test)
...@@ -377,6 +396,16 @@ TEST_CASE(double_invert_target_test) ...@@ -377,6 +396,16 @@ TEST_CASE(double_invert_target_test)
EXPECT(result != migraphx::literal{4}); EXPECT(result != migraphx::literal{4});
} }
TEST_CASE(reverse_target_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
EXPECT(test::throws<migraphx::exception>([&]{ p.compile(reverse_target{}); }));
}
// Check that the program doesnt modify the context directly, and only the operators modify the // Check that the program doesnt modify the context directly, and only the operators modify the
// context // context
TEST_CASE(eval_context1) TEST_CASE(eval_context1)
...@@ -425,4 +454,49 @@ TEST_CASE(eval_context3) ...@@ -425,4 +454,49 @@ TEST_CASE(eval_context3)
EXPECT(not is_shared(t.ctx, p.get_context())); EXPECT(not is_shared(t.ctx, p.get_context()));
} }
struct cout_redirect {
cout_redirect()=delete;
cout_redirect(const cout_redirect&)=delete;
template<class T>
cout_redirect(T& stream)
: old(std::cout.rdbuf(stream.rdbuf()))
{}
~cout_redirect()
{
std::cout.rdbuf(old);
}
private:
std::streambuf * old;
};
template<class F>
std::string capture_output(F f)
{
std::stringstream ss;
cout_redirect cr{ss};
f();
return ss.str();
}
TEST_CASE(debug_print_test)
{
migraphx::program p;
auto one = p.add_literal(1);
migraphx::program p2;
auto one2 = p2.add_literal(1);
auto program_out = migraphx::trim(capture_output([&]{ p.debug_print(); }));
auto ins_out = migraphx::trim(capture_output([&]{ p.debug_print(one); }));
auto inss_out = migraphx::trim(capture_output([&]{ p.debug_print({one}); }));
auto end_out = migraphx::trim(capture_output([&]{ p.debug_print(p.end()); }));
auto p2_ins_out = migraphx::trim(capture_output([&]{ p.debug_print(one2); }));
EXPECT(program_out == ins_out);
EXPECT(inss_out == ins_out);
EXPECT(end_out == "End instruction");
EXPECT(p2_ins_out == "Instruction not part of program");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment