Commit 414e2fac authored by Paul's avatar Paul
Browse files

Merge branch 'test-driver'

parents bf571e25 0b7469d3
...@@ -14,6 +14,7 @@ struct contiguous_target ...@@ -14,6 +14,7 @@ struct contiguous_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
// TODO: Add this test case
void literal_broadcast() void literal_broadcast()
{ {
migraph::program p; migraph::program p;
...@@ -25,7 +26,7 @@ void literal_broadcast() ...@@ -25,7 +26,7 @@ void literal_broadcast()
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
} }
void literal_transpose() TEST_CASE(literal_transpose)
{ {
migraph::program p; migraph::program p;
p.add_literal(get_2x2_transposed()); p.add_literal(get_2x2_transposed());
...@@ -36,7 +37,7 @@ void literal_transpose() ...@@ -36,7 +37,7 @@ void literal_transpose()
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
} }
void after_literal_transpose() TEST_CASE(after_literal_transpose)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -51,7 +52,7 @@ void after_literal_transpose() ...@@ -51,7 +52,7 @@ void after_literal_transpose()
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
} }
void after_literal_broadcast() TEST_CASE(after_literal_broadcast)
{ {
migraph::program p; migraph::program p;
auto l1 = p.add_literal(get_2x2()); auto l1 = p.add_literal(get_2x2());
...@@ -67,7 +68,7 @@ void after_literal_broadcast() ...@@ -67,7 +68,7 @@ void after_literal_broadcast()
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
} }
void after_param_transpose() TEST_CASE(after_param_transpose)
{ {
migraph::program p; migraph::program p;
auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}}); auto l = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}});
...@@ -82,7 +83,7 @@ void after_param_transpose() ...@@ -82,7 +83,7 @@ void after_param_transpose()
EXPECT(not p.get_shape().transposed()); EXPECT(not p.get_shape().transposed());
} }
void after_param_broadcast() TEST_CASE(after_param_broadcast)
{ {
migraph::program p; migraph::program p;
auto l1 = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}}); auto l1 = p.add_parameter("2x2", {migraph::shape::float_type, {2, 2}});
...@@ -98,12 +99,4 @@ void after_param_broadcast() ...@@ -98,12 +99,4 @@ void after_param_broadcast()
EXPECT(not p.get_shape().broadcasted()); EXPECT(not p.get_shape().broadcasted());
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
// literal_broadcast();
literal_transpose();
after_literal_transpose();
after_literal_broadcast();
after_param_transpose();
after_param_broadcast();
}
...@@ -14,7 +14,7 @@ struct cse_target ...@@ -14,7 +14,7 @@ struct cse_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void cse_test1() TEST_CASE(cse_test1)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -38,7 +38,7 @@ void cse_test1() ...@@ -38,7 +38,7 @@ void cse_test1()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void cse_test2() TEST_CASE(cse_test2)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -63,7 +63,7 @@ void cse_test2() ...@@ -63,7 +63,7 @@ void cse_test2()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void cse_test3() TEST_CASE(cse_test3)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -86,7 +86,7 @@ void cse_test3() ...@@ -86,7 +86,7 @@ void cse_test3()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void cse_test4() TEST_CASE(cse_test4)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -112,10 +112,4 @@ void cse_test4() ...@@ -112,10 +112,4 @@ void cse_test4()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
cse_test1();
cse_test2();
cse_test3();
cse_test4();
}
...@@ -14,7 +14,7 @@ struct const_prop_target ...@@ -14,7 +14,7 @@ struct const_prop_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void const_add1() TEST_CASE(const_add1)
{ {
migraph::program p1; migraph::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -29,7 +29,7 @@ void const_add1() ...@@ -29,7 +29,7 @@ void const_add1()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void const_add2() TEST_CASE(const_add2)
{ {
migraph::program p1; migraph::program p1;
auto one = p1.add_parameter("one", {migraph::shape::int32_type, {1}}); auto one = p1.add_parameter("one", {migraph::shape::int32_type, {1}});
...@@ -44,7 +44,7 @@ void const_add2() ...@@ -44,7 +44,7 @@ void const_add2()
EXPECT(p1 != p2); EXPECT(p1 != p2);
} }
void const_add3() TEST_CASE(const_add3)
{ {
migraph::program p1; migraph::program p1;
auto one = p1.add_literal(1); auto one = p1.add_literal(1);
...@@ -60,9 +60,4 @@ void const_add3() ...@@ -60,9 +60,4 @@ void const_add3()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
const_add1();
const_add2();
const_add3();
}
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
#include "test.hpp" #include "test.hpp"
void slice_test() TEST_CASE(slice_test)
{ {
{ {
migraph::program p; migraph::program p;
...@@ -47,7 +47,7 @@ void slice_test() ...@@ -47,7 +47,7 @@ void slice_test()
} }
} }
void concat_test() TEST_CASE(concat_test)
{ {
{ {
migraph::program p; migraph::program p;
...@@ -97,7 +97,7 @@ void concat_test() ...@@ -97,7 +97,7 @@ void concat_test()
} }
} }
void squeeze_test() TEST_CASE(squeeze_test)
{ {
{ {
migraph::program p; migraph::program p;
...@@ -134,7 +134,7 @@ void squeeze_test() ...@@ -134,7 +134,7 @@ void squeeze_test()
} }
} }
void unsqueeze_test() TEST_CASE(unsqueeze_test)
{ {
{ {
migraph::program p; migraph::program p;
...@@ -160,7 +160,7 @@ void unsqueeze_test() ...@@ -160,7 +160,7 @@ void unsqueeze_test()
} }
} }
void globalavgpool_test() TEST_CASE(globalavgpool_test)
{ {
migraph::program p; migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}}; auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
...@@ -180,7 +180,7 @@ void globalavgpool_test() ...@@ -180,7 +180,7 @@ void globalavgpool_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void globalmaxpool_test() TEST_CASE(globalmaxpool_test)
{ {
migraph::program p; migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}}; auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
...@@ -200,7 +200,7 @@ void globalmaxpool_test() ...@@ -200,7 +200,7 @@ void globalmaxpool_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void im2col_3x3_no_pad_identity_test() TEST_CASE(im2col_3x3_no_pad_identity_test)
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
std::size_t size[2] = {3, 3}; std::size_t size[2] = {3, 3};
...@@ -229,7 +229,7 @@ void im2col_3x3_no_pad_identity_test() ...@@ -229,7 +229,7 @@ void im2col_3x3_no_pad_identity_test()
EXPECT(migraph::verify_range(results_vector, input)); EXPECT(migraph::verify_range(results_vector, input));
} }
void im2col_3x3_no_pad_test() TEST_CASE(im2col_3x3_no_pad_test)
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
std::size_t size[2] = {4, 4}; std::size_t size[2] = {4, 4};
...@@ -261,7 +261,7 @@ void im2col_3x3_no_pad_test() ...@@ -261,7 +261,7 @@ void im2col_3x3_no_pad_test()
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraph::verify_range(results_vector, correct));
} }
void im2col_3x3_stride_2_no_pad_test() TEST_CASE(im2col_3x3_stride_2_no_pad_test)
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
std::size_t size[2] = {6, 6}; std::size_t size[2] = {6, 6};
...@@ -294,7 +294,7 @@ void im2col_3x3_stride_2_no_pad_test() ...@@ -294,7 +294,7 @@ void im2col_3x3_stride_2_no_pad_test()
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraph::verify_range(results_vector, correct));
} }
void im2col_3x3_with_padding_test() TEST_CASE(im2col_3x3_with_padding_test)
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
std::size_t size[2] = {2, 2}; std::size_t size[2] = {2, 2};
...@@ -326,7 +326,7 @@ void im2col_3x3_with_padding_test() ...@@ -326,7 +326,7 @@ void im2col_3x3_with_padding_test()
EXPECT(migraph::verify_range(results_vector, correct)); EXPECT(migraph::verify_range(results_vector, correct));
} }
void batch_norm_inference_test() TEST_CASE(batch_norm_inference_test)
{ {
migraph::program p; migraph::program p;
const size_t width = 2, height = 2, channels = 4, batches = 2; const size_t width = 2, height = 2, channels = 4, batches = 2;
...@@ -366,7 +366,7 @@ void batch_norm_inference_test() ...@@ -366,7 +366,7 @@ void batch_norm_inference_test()
EXPECT(migraph::verify_range(result_vector, gold)); EXPECT(migraph::verify_range(result_vector, gold));
} }
void im2col_3x3_with_channels_identity_test() TEST_CASE(im2col_3x3_with_channels_identity_test)
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
std::size_t size[2] = {3, 3}; std::size_t size[2] = {3, 3};
...@@ -395,7 +395,7 @@ void im2col_3x3_with_channels_identity_test() ...@@ -395,7 +395,7 @@ void im2col_3x3_with_channels_identity_test()
EXPECT(migraph::verify_range(results_vector, input)); EXPECT(migraph::verify_range(results_vector, input));
} }
void exp_test() TEST_CASE(exp_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -409,7 +409,7 @@ void exp_test() ...@@ -409,7 +409,7 @@ void exp_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void sin_test() TEST_CASE(sin_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -423,7 +423,7 @@ void sin_test() ...@@ -423,7 +423,7 @@ void sin_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void cos_test() TEST_CASE(cos_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -437,7 +437,7 @@ void cos_test() ...@@ -437,7 +437,7 @@ void cos_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void tan_test() TEST_CASE(tan_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -451,7 +451,7 @@ void tan_test() ...@@ -451,7 +451,7 @@ void tan_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void add_test() TEST_CASE(add_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -466,7 +466,7 @@ void add_test() ...@@ -466,7 +466,7 @@ void add_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void broadcast_test() TEST_CASE(broadcast_test)
{ {
migraph::program p; migraph::program p;
migraph::shape a_shape{migraph::shape::int32_type, {2, 2}}; migraph::shape a_shape{migraph::shape::int32_type, {2, 2}};
...@@ -485,7 +485,7 @@ void broadcast_test() ...@@ -485,7 +485,7 @@ void broadcast_test()
EXPECT(output(1, 0) == -3); EXPECT(output(1, 0) == -3);
EXPECT(output(1, 1) == -3); EXPECT(output(1, 1) == -3);
} }
void add_broadcast_test() TEST_CASE(add_broadcast_test)
{ {
migraph::program p; migraph::program p;
migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}}; migraph::shape a_shape{migraph::shape::float_type, {2, 2, 3}};
...@@ -506,7 +506,7 @@ void add_broadcast_test() ...@@ -506,7 +506,7 @@ void add_broadcast_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void sub_test() TEST_CASE(sub_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -521,7 +521,7 @@ void sub_test() ...@@ -521,7 +521,7 @@ void sub_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void mul_test() TEST_CASE(mul_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -536,7 +536,7 @@ void mul_test() ...@@ -536,7 +536,7 @@ void mul_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void div_test() TEST_CASE(div_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -551,7 +551,7 @@ void div_test() ...@@ -551,7 +551,7 @@ void div_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void relu_test() TEST_CASE(relu_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -565,7 +565,7 @@ void relu_test() ...@@ -565,7 +565,7 @@ void relu_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void leaky_relu_test() TEST_CASE(leaky_relu_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {3}}; migraph::shape s{migraph::shape::float_type, {3}};
...@@ -579,7 +579,7 @@ void leaky_relu_test() ...@@ -579,7 +579,7 @@ void leaky_relu_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void imagescaler_test() TEST_CASE(imagescaler_test)
{ {
migraph::program p; migraph::program p;
migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}}; migraph::shape s{migraph::shape::float_type, {1, 3, 2, 2}};
...@@ -626,7 +626,7 @@ void imagescaler_test() ...@@ -626,7 +626,7 @@ void imagescaler_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
void reshape_test() TEST_CASE(reshape_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {24, 1, 1, 1}}; migraph::shape a_shape{migraph::shape::float_type, {24, 1, 1, 1}};
std::vector<float> data(24); std::vector<float> data(24);
...@@ -716,8 +716,10 @@ void gemm_test() ...@@ -716,8 +716,10 @@ void gemm_test()
EXPECT(std::abs(results_vector[i] - c[i]) < tol); EXPECT(std::abs(results_vector[i] - c[i]) < tol);
} }
} }
TEST_CASE_REGISTER(gemm_test<float>)
TEST_CASE_REGISTER(gemm_test<double>)
void maxpool_test() TEST_CASE(maxpool_test)
{ {
migraph::program p; migraph::program p;
std::vector<float> a = { std::vector<float> a = {
...@@ -763,7 +765,7 @@ void maxpool_test() ...@@ -763,7 +765,7 @@ void maxpool_test()
p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al); p.add_instruction(migraph::op::pooling{"max", {{0, 0}}, {{2, 2}}, {{3, 2}}}, al);
p.compile(migraph::cpu::target{}); p.compile(migraph::cpu::target{});
auto result = p.eval({}); auto result = p.eval({});
std::cout << result.get_shape() << std::endl; // std::cout << result.get_shape() << std::endl;
std::vector<float> results_vector(36); std::vector<float> results_vector(36);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
float tol = 1e-6; float tol = 1e-6;
...@@ -774,7 +776,7 @@ void maxpool_test() ...@@ -774,7 +776,7 @@ void maxpool_test()
} }
} }
void softmax_test() TEST_CASE(softmax_test)
{ {
migraph::program p; migraph::program p;
std::vector<float> a = { std::vector<float> a = {
...@@ -833,7 +835,7 @@ void softmax_test() ...@@ -833,7 +835,7 @@ void softmax_test()
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_test() TEST_CASE(conv2d_test)
{ {
migraph::program p; migraph::program p;
std::vector<float> a = { std::vector<float> a = {
...@@ -896,7 +898,7 @@ void conv2d_test() ...@@ -896,7 +898,7 @@ void conv2d_test()
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_padding_test() TEST_CASE(conv2d_padding_test)
{ {
migraph::program p; migraph::program p;
std::vector<float> a = { std::vector<float> a = {
...@@ -952,7 +954,7 @@ void conv2d_padding_test() ...@@ -952,7 +954,7 @@ void conv2d_padding_test()
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void conv2d_padding_stride_test() TEST_CASE(conv2d_padding_stride_test)
{ {
migraph::program p; migraph::program p;
std::vector<float> a = { std::vector<float> a = {
...@@ -1013,7 +1015,7 @@ void conv2d_padding_stride_test() ...@@ -1013,7 +1015,7 @@ void conv2d_padding_stride_test()
EXPECT(migraph::verify_range(results_vector, s)); EXPECT(migraph::verify_range(results_vector, s));
} }
void transpose_test() TEST_CASE(transpose_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {1, 2, 2, 3}}; migraph::shape a_shape{migraph::shape::float_type, {1, 2, 2, 3}};
std::vector<float> data(12); std::vector<float> data(12);
...@@ -1048,7 +1050,7 @@ void transpose_test() ...@@ -1048,7 +1050,7 @@ void transpose_test()
} }
} }
void contiguous_test() TEST_CASE(contiguous_test)
{ {
migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}}; migraph::shape a_shape{migraph::shape::float_type, {1, 3, 2, 2}, {12, 1, 6, 3}};
std::vector<float> data(12); std::vector<float> data(12);
...@@ -1068,41 +1070,4 @@ void contiguous_test() ...@@ -1068,41 +1070,4 @@ void contiguous_test()
EXPECT(migraph::verify_range(results_vector, gold)); EXPECT(migraph::verify_range(results_vector, gold));
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
concat_test();
slice_test();
squeeze_test();
unsqueeze_test();
exp_test();
sin_test();
cos_test();
tan_test();
add_test();
broadcast_test();
add_broadcast_test();
imagescaler_test();
sub_test();
mul_test();
div_test();
relu_test();
leaky_relu_test();
gemm_test<float>();
gemm_test<double>();
reshape_test();
transpose_test();
// contiguous_test();
softmax_test();
// maxpool_test();
conv2d_test();
conv2d_padding_test();
conv2d_padding_stride_test();
batch_norm_inference_test();
globalavgpool_test();
globalmaxpool_test();
im2col_3x3_no_pad_identity_test();
im2col_3x3_no_pad_test();
im2col_3x3_stride_2_no_pad_test();
im2col_3x3_with_channels_identity_test();
im2col_3x3_with_padding_test();
}
...@@ -12,7 +12,7 @@ struct dce_target ...@@ -12,7 +12,7 @@ struct dce_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void simple_test() TEST_CASE(simple_test)
{ {
migraph::program p; migraph::program p;
...@@ -27,7 +27,7 @@ void simple_test() ...@@ -27,7 +27,7 @@ void simple_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void simple_test_nop() TEST_CASE(simple_test_nop)
{ {
migraph::program p; migraph::program p;
...@@ -43,7 +43,7 @@ void simple_test_nop() ...@@ -43,7 +43,7 @@ void simple_test_nop()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void simple_test_nop2() TEST_CASE(simple_test_nop2)
{ {
migraph::program p; migraph::program p;
...@@ -59,7 +59,7 @@ void simple_test_nop2() ...@@ -59,7 +59,7 @@ void simple_test_nop2()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void duplicate_test1() TEST_CASE(duplicate_test1)
{ {
migraph::program p; migraph::program p;
...@@ -75,7 +75,7 @@ void duplicate_test1() ...@@ -75,7 +75,7 @@ void duplicate_test1()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void duplicate_test2() TEST_CASE(duplicate_test2)
{ {
migraph::program p; migraph::program p;
...@@ -92,7 +92,7 @@ void duplicate_test2() ...@@ -92,7 +92,7 @@ void duplicate_test2()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void depth_test() TEST_CASE(depth_test)
{ {
migraph::program p; migraph::program p;
...@@ -111,12 +111,4 @@ void depth_test() ...@@ -111,12 +111,4 @@ void depth_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
simple_test();
simple_test_nop();
simple_test_nop2();
duplicate_test1();
duplicate_test2();
depth_test();
}
...@@ -32,7 +32,7 @@ struct allocate ...@@ -32,7 +32,7 @@ struct allocate
} }
}; };
void basic() TEST_CASE(basic)
{ {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}}); auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {8}}});
...@@ -49,7 +49,7 @@ void basic() ...@@ -49,7 +49,7 @@ void basic()
EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (8 * 4 + 40 * 4 + 200 * 4));
} }
void aligned() TEST_CASE(aligned)
{ {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
...@@ -66,7 +66,7 @@ void aligned() ...@@ -66,7 +66,7 @@ void aligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (32 + 32 + 200 * 4));
} }
void unaligned() TEST_CASE(unaligned)
{ {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
...@@ -83,7 +83,7 @@ void unaligned() ...@@ -83,7 +83,7 @@ void unaligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
void float_aligned() TEST_CASE(float_aligned)
{ {
migraph::program p; migraph::program p;
auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}}); auto a1 = p.add_instruction(allocate{migraph::shape{migraph::shape::float_type, {1}}});
...@@ -100,11 +100,8 @@ void float_aligned() ...@@ -100,11 +100,8 @@ void float_aligned()
EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4)); EXPECT(p.get_parameter_shape("memory").bytes() == (1 * 4 + 2 * 4 + 200 * 4));
} }
int main() int main(int argc, const char* argv[])
{ {
setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1); setenv("MIGRAPH_DISABLE_MEMORY_COLORING", "1", 1);
basic(); test::run(argc, argv);
aligned();
unaligned();
float_aligned();
} }
...@@ -79,7 +79,7 @@ struct fred_op ...@@ -79,7 +79,7 @@ struct fred_op
} }
}; };
void basic() TEST_CASE(basic)
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraph::program p;
...@@ -123,7 +123,7 @@ void basic() ...@@ -123,7 +123,7 @@ void basic()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void wont_work() TEST_CASE(wont_work)
{ {
auto create_test_program = []() { auto create_test_program = []() {
migraph::program p; migraph::program p;
...@@ -167,8 +167,4 @@ void wont_work() ...@@ -167,8 +167,4 @@ void wont_work()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
basic();
wont_work();
}
...@@ -14,7 +14,7 @@ struct eliminate_contiguous_target ...@@ -14,7 +14,7 @@ struct eliminate_contiguous_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void standard_op() TEST_CASE(standard_op)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -26,7 +26,7 @@ void standard_op() ...@@ -26,7 +26,7 @@ void standard_op()
EXPECT(std::distance(p.begin(), p.end()) == count); EXPECT(std::distance(p.begin(), p.end()) == count);
} }
void non_standard_op() TEST_CASE(non_standard_op)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(get_2x2()); auto l = p.add_literal(get_2x2());
...@@ -38,8 +38,4 @@ void non_standard_op() ...@@ -38,8 +38,4 @@ void non_standard_op()
EXPECT(std::distance(p.begin(), p.end()) == (count - 1)); EXPECT(std::distance(p.begin(), p.end()) == (count - 1));
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
standard_op();
non_standard_op();
}
...@@ -50,7 +50,7 @@ struct double_reverse_target ...@@ -50,7 +50,7 @@ struct double_reverse_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void literal_test1() TEST_CASE(literal_test1)
{ {
migraph::program p; migraph::program p;
...@@ -62,7 +62,7 @@ void literal_test1() ...@@ -62,7 +62,7 @@ void literal_test1()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void literal_test2() TEST_CASE(literal_test2)
{ {
migraph::program p; migraph::program p;
...@@ -76,7 +76,7 @@ void literal_test2() ...@@ -76,7 +76,7 @@ void literal_test2()
EXPECT(result != migraph::literal{3}); EXPECT(result != migraph::literal{3});
} }
void print_test() TEST_CASE(print_test)
{ {
migraph::program p; migraph::program p;
...@@ -90,7 +90,7 @@ void print_test() ...@@ -90,7 +90,7 @@ void print_test()
EXPECT(!s.empty()); EXPECT(!s.empty());
} }
void param_test() TEST_CASE(param_test)
{ {
migraph::program p; migraph::program p;
...@@ -104,7 +104,7 @@ void param_test() ...@@ -104,7 +104,7 @@ void param_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void param_error_test() TEST_CASE(param_error_test)
{ {
migraph::program p; migraph::program p;
...@@ -119,7 +119,7 @@ void param_error_test() ...@@ -119,7 +119,7 @@ void param_error_test()
"Parameter not found: y")); "Parameter not found: y"));
} }
void replace_test() TEST_CASE(replace_test)
{ {
migraph::program p; migraph::program p;
...@@ -134,7 +134,7 @@ void replace_test() ...@@ -134,7 +134,7 @@ void replace_test()
EXPECT(result != migraph::literal{3}); EXPECT(result != migraph::literal{3});
} }
void replace_ins_test() TEST_CASE(replace_ins_test)
{ {
migraph::program p; migraph::program p;
...@@ -150,7 +150,7 @@ void replace_ins_test() ...@@ -150,7 +150,7 @@ void replace_ins_test()
EXPECT(result != migraph::literal{3}); EXPECT(result != migraph::literal{3});
} }
void replace_ins_test2() TEST_CASE(replace_ins_test2)
{ {
migraph::program p; migraph::program p;
...@@ -167,7 +167,7 @@ void replace_ins_test2() ...@@ -167,7 +167,7 @@ void replace_ins_test2()
EXPECT(result != migraph::literal{3}); EXPECT(result != migraph::literal{3});
} }
void insert_replace_test() TEST_CASE(insert_replace_test)
{ {
migraph::program p; migraph::program p;
...@@ -185,7 +185,7 @@ void insert_replace_test() ...@@ -185,7 +185,7 @@ void insert_replace_test()
EXPECT(result != migraph::literal{5}); EXPECT(result != migraph::literal{5});
} }
void target_test() TEST_CASE(target_test)
{ {
migraph::program p; migraph::program p;
...@@ -198,7 +198,7 @@ void target_test() ...@@ -198,7 +198,7 @@ void target_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void reverse_target_test() TEST_CASE(reverse_target_test)
{ {
migraph::program p; migraph::program p;
...@@ -211,7 +211,7 @@ void reverse_target_test() ...@@ -211,7 +211,7 @@ void reverse_target_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
void double_reverse_target_test() TEST_CASE(double_reverse_target_test)
{ {
migraph::program p; migraph::program p;
...@@ -224,17 +224,4 @@ void double_reverse_target_test() ...@@ -224,17 +224,4 @@ void double_reverse_target_test()
EXPECT(result != migraph::literal{4}); EXPECT(result != migraph::literal{4});
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
literal_test1();
literal_test2();
print_test();
param_test();
param_error_test();
replace_test();
replace_ins_test();
replace_ins_test2();
insert_replace_test();
target_test();
reverse_target_test();
}
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <test.hpp> #include <test.hpp>
#include <migraph/verify.hpp> #include <migraph/verify.hpp>
void fwd_conv_batchnorm_rewrite_test() TEST_CASE(fwd_conv_batchnorm_rewrite_test)
{ {
std::vector<float> xdata = { std::vector<float> xdata = {
0.26485917, 0.61703885, 0.32762103, 0.2503367, 0.6552712, 0.07947932, 0.95442678, 0.26485917, 0.61703885, 0.32762103, 0.2503367, 0.6552712, 0.07947932, 0.95442678,
...@@ -64,8 +64,4 @@ void fwd_conv_batchnorm_rewrite_test() ...@@ -64,8 +64,4 @@ void fwd_conv_batchnorm_rewrite_test()
EXPECT(migraph::verify_range(results_vector1, results_vector2)); EXPECT(migraph::verify_range(results_vector1, results_vector2));
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
fwd_conv_batchnorm_rewrite_test();
return 0;
}
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
#include <cassert> #include <cassert>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <functional>
#include <iostream> #include <iostream>
#include <unordered_map>
#include <vector>
#ifndef MIGRAPH_GUARD_TEST_TEST_HPP #ifndef MIGRAPH_GUARD_TEST_TEST_HPP
#define MIGRAPH_GUARD_TEST_TEST_HPP #define MIGRAPH_GUARD_TEST_TEST_HPP
...@@ -154,11 +157,75 @@ bool throws(F f, const std::string& msg = "") ...@@ -154,11 +157,75 @@ bool throws(F f, const std::string& msg = "")
} }
} }
template <class T> using string_map = std::unordered_map<std::string, std::vector<std::string>>;
void run_test()
template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword)
{
string_map result;
std::string flag;
for(auto&& x : as)
{
auto f = keyword(x);
if(f.empty())
{
result[flag].push_back(x);
}
else
{
flag = f.front();
result[flag]; // Ensure the flag exists
}
}
return result;
}
inline auto& get_test_cases()
{
static std::vector<std::pair<std::string, std::function<void()>>> cases;
return cases;
}
inline void add_test_case(std::string name, std::function<void()> f)
{ {
T t = {}; get_test_cases().emplace_back(name, f);
t.run(); }
struct auto_register
{
template <class F>
auto_register(const char* name, F f) noexcept
{
add_test_case(name, f);
}
};
inline void run_test_case(const std::string& name, const std::function<void()>& f)
{
std::cout << "[ RUN ] " << name << std::endl;
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
inline void run(int argc, const char* argv[])
{
std::vector<std::string> as(argv + 1, argv + argc);
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; });
auto cases = args[""];
if(cases.empty())
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second);
}
else
{
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& name : cases)
run_test_case(name, m[name]);
}
} }
} // namespace test } // namespace test
...@@ -179,4 +246,24 @@ void run_test() ...@@ -179,4 +246,24 @@ void run_test()
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0) #define STATUS(...) EXPECT((__VA_ARGS__) == 0)
// NOLINTNEXTLINE
#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__)
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
void __VA_ARGS__(); \
TEST_CASE_REGISTER(__VA_ARGS__) \
void __VA_ARGS__()
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
#endif #endif
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <string> #include <string>
#include "test.hpp" #include "test.hpp"
void literal_test() TEST_CASE(literal_test)
{ {
EXPECT(migraph::literal{1} == migraph::literal{1}); EXPECT(migraph::literal{1} == migraph::literal{1});
EXPECT(migraph::literal{1} != migraph::literal{2}); EXPECT(migraph::literal{1} != migraph::literal{2});
...@@ -25,7 +25,7 @@ void literal_test() ...@@ -25,7 +25,7 @@ void literal_test()
EXPECT(l4.empty()); EXPECT(l4.empty());
} }
void literal_os1() TEST_CASE(literal_os1)
{ {
migraph::literal l{1}; migraph::literal l{1};
std::stringstream ss; std::stringstream ss;
...@@ -33,7 +33,7 @@ void literal_os1() ...@@ -33,7 +33,7 @@ void literal_os1()
EXPECT(ss.str() == "1"); EXPECT(ss.str() == "1");
} }
void literal_os2() TEST_CASE(literal_os2)
{ {
migraph::literal l{}; migraph::literal l{};
std::stringstream ss; std::stringstream ss;
...@@ -41,7 +41,7 @@ void literal_os2() ...@@ -41,7 +41,7 @@ void literal_os2()
EXPECT(ss.str().empty()); EXPECT(ss.str().empty());
} }
void literal_os3() TEST_CASE(literal_os3)
{ {
migraph::shape s{migraph::shape::int64_type, {3}}; migraph::shape s{migraph::shape::int64_type, {3}};
migraph::literal l{s, {1, 2, 3}}; migraph::literal l{s, {1, 2, 3}};
...@@ -50,9 +50,4 @@ void literal_os3() ...@@ -50,9 +50,4 @@ void literal_os3()
EXPECT(ss.str() == "1, 2, 3"); EXPECT(ss.str() == "1, 2, 3");
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
literal_test();
literal_os1();
literal_os2();
}
...@@ -27,7 +27,7 @@ void match1() ...@@ -27,7 +27,7 @@ void match1()
EXPECT(bool{r.result == l}); EXPECT(bool{r.result == l});
} }
void match_name1() TEST_CASE(match_name1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -39,7 +39,7 @@ void match_name1() ...@@ -39,7 +39,7 @@ void match_name1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_name2() TEST_CASE(match_name2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -51,7 +51,7 @@ void match_name2() ...@@ -51,7 +51,7 @@ void match_name2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_name3() TEST_CASE(match_name3)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -63,7 +63,7 @@ void match_name3() ...@@ -63,7 +63,7 @@ void match_name3()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_arg1() TEST_CASE(match_arg1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -75,7 +75,7 @@ void match_arg1() ...@@ -75,7 +75,7 @@ void match_arg1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_arg2() TEST_CASE(match_arg2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -87,7 +87,7 @@ void match_arg2() ...@@ -87,7 +87,7 @@ void match_arg2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_arg3() TEST_CASE(match_arg3)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -99,7 +99,7 @@ void match_arg3() ...@@ -99,7 +99,7 @@ void match_arg3()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_arg4() TEST_CASE(match_arg4)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -111,7 +111,7 @@ void match_arg4() ...@@ -111,7 +111,7 @@ void match_arg4()
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
void match_arg5() TEST_CASE(match_arg5)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -123,7 +123,7 @@ void match_arg5() ...@@ -123,7 +123,7 @@ void match_arg5()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_arg6() TEST_CASE(match_arg6)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -135,7 +135,7 @@ void match_arg6() ...@@ -135,7 +135,7 @@ void match_arg6()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_arg7() TEST_CASE(match_arg7)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -148,7 +148,7 @@ void match_arg7() ...@@ -148,7 +148,7 @@ void match_arg7()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_args1() TEST_CASE(match_args1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -161,7 +161,7 @@ void match_args1() ...@@ -161,7 +161,7 @@ void match_args1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_args2() TEST_CASE(match_args2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -174,7 +174,7 @@ void match_args2() ...@@ -174,7 +174,7 @@ void match_args2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_args3() TEST_CASE(match_args3)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -186,7 +186,7 @@ void match_args3() ...@@ -186,7 +186,7 @@ void match_args3()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_args4() TEST_CASE(match_args4)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -200,7 +200,7 @@ void match_args4() ...@@ -200,7 +200,7 @@ void match_args4()
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
void match_args5() TEST_CASE(match_args5)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -213,7 +213,7 @@ void match_args5() ...@@ -213,7 +213,7 @@ void match_args5()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_args6() TEST_CASE(match_args6)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -225,7 +225,7 @@ void match_args6() ...@@ -225,7 +225,7 @@ void match_args6()
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
void match_args7() TEST_CASE(match_args7)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -239,7 +239,7 @@ void match_args7() ...@@ -239,7 +239,7 @@ void match_args7()
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
void match_either_args1() TEST_CASE(match_either_args1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -253,7 +253,7 @@ void match_either_args1() ...@@ -253,7 +253,7 @@ void match_either_args1()
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
void match_either_args2() TEST_CASE(match_either_args2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -267,7 +267,7 @@ void match_either_args2() ...@@ -267,7 +267,7 @@ void match_either_args2()
EXPECT(bool{r.result == sum2}); EXPECT(bool{r.result == sum2});
} }
void match_either_args3() TEST_CASE(match_either_args3)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -281,7 +281,7 @@ void match_either_args3() ...@@ -281,7 +281,7 @@ void match_either_args3()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_all_of1() TEST_CASE(match_all_of1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -294,7 +294,7 @@ void match_all_of1() ...@@ -294,7 +294,7 @@ void match_all_of1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_all_of2() TEST_CASE(match_all_of2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -307,7 +307,7 @@ void match_all_of2() ...@@ -307,7 +307,7 @@ void match_all_of2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_any_of1() TEST_CASE(match_any_of1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -320,7 +320,7 @@ void match_any_of1() ...@@ -320,7 +320,7 @@ void match_any_of1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_any_of2() TEST_CASE(match_any_of2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -333,7 +333,7 @@ void match_any_of2() ...@@ -333,7 +333,7 @@ void match_any_of2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_none_of1() TEST_CASE(match_none_of1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -346,7 +346,7 @@ void match_none_of1() ...@@ -346,7 +346,7 @@ void match_none_of1()
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
void match_none_of2() TEST_CASE(match_none_of2)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -359,7 +359,7 @@ void match_none_of2() ...@@ -359,7 +359,7 @@ void match_none_of2()
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
void match_bind1() TEST_CASE(match_bind1)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -400,7 +400,7 @@ struct match_find_literal ...@@ -400,7 +400,7 @@ struct match_find_literal
} }
}; };
void match_finder() TEST_CASE(match_finder)
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
...@@ -410,43 +410,4 @@ void match_finder() ...@@ -410,43 +410,4 @@ void match_finder()
match::find_matches(p, match_find_sum{sum}, match_find_literal{sum}); match::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
match1();
match_name1();
match_name2();
match_name3();
match_arg1();
match_arg2();
match_arg3();
match_arg4();
match_arg5();
match_arg6();
match_arg7();
match_args1();
match_args2();
match_args3();
match_args4();
match_args5();
match_args6();
match_args7();
match_either_args1();
match_either_args2();
match_either_args3();
match_all_of1();
match_all_of2();
match_any_of1();
match_any_of2();
match_none_of1();
match_none_of2();
match_bind1();
match_finder();
}
...@@ -43,7 +43,7 @@ bool no_allocate(const migraph::program& p) ...@@ -43,7 +43,7 @@ bool no_allocate(const migraph::program& p)
return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; }); return std::none_of(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "allocate"; });
} }
void test1() TEST_CASE(test1)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -55,7 +55,7 @@ void test1() ...@@ -55,7 +55,7 @@ void test1()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test2() TEST_CASE(test2)
{ {
migraph::program p; migraph::program p;
auto input = p.add_parameter("input", migraph::shape{migraph::shape::float_type, {16}}); auto input = p.add_parameter("input", migraph::shape{migraph::shape::float_type, {16}});
...@@ -69,7 +69,7 @@ void test2() ...@@ -69,7 +69,7 @@ void test2()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test3() TEST_CASE(test3)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -82,7 +82,7 @@ void test3() ...@@ -82,7 +82,7 @@ void test3()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test4() TEST_CASE(test4)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {0}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {0}});
...@@ -95,7 +95,7 @@ void test4() ...@@ -95,7 +95,7 @@ void test4()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test5() TEST_CASE(test5)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
...@@ -107,7 +107,7 @@ void test5() ...@@ -107,7 +107,7 @@ void test5()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test6() TEST_CASE(test6)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -120,7 +120,7 @@ void test6() ...@@ -120,7 +120,7 @@ void test6()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test7() TEST_CASE(test7)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -133,7 +133,7 @@ void test7() ...@@ -133,7 +133,7 @@ void test7()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test8() TEST_CASE(test8)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -146,7 +146,7 @@ void test8() ...@@ -146,7 +146,7 @@ void test8()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test9() TEST_CASE(test9)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -159,7 +159,7 @@ void test9() ...@@ -159,7 +159,7 @@ void test9()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test10() TEST_CASE(test10)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -169,7 +169,7 @@ void test10() ...@@ -169,7 +169,7 @@ void test10()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test11() TEST_CASE(test11)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -183,7 +183,7 @@ void test11() ...@@ -183,7 +183,7 @@ void test11()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test12() TEST_CASE(test12)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
...@@ -197,7 +197,7 @@ void test12() ...@@ -197,7 +197,7 @@ void test12()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test13() TEST_CASE(test13)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -211,7 +211,7 @@ void test13() ...@@ -211,7 +211,7 @@ void test13()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test14() TEST_CASE(test14)
{ {
migraph::program p; migraph::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a3 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -225,7 +225,7 @@ void test14() ...@@ -225,7 +225,7 @@ void test14()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test15() TEST_CASE(test15)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -239,7 +239,7 @@ void test15() ...@@ -239,7 +239,7 @@ void test15()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test16() TEST_CASE(test16)
{ {
migraph::program p; migraph::program p;
auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}})); auto a1 = p.add_literal(migraph::generate_literal({migraph::shape::float_type, {8}}));
...@@ -253,7 +253,7 @@ void test16() ...@@ -253,7 +253,7 @@ void test16()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test17() TEST_CASE(test17)
{ {
migraph::program p; migraph::program p;
auto a3 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a3 = add_alloc(p, {migraph::shape::float_type, {40}});
...@@ -267,7 +267,7 @@ void test17() ...@@ -267,7 +267,7 @@ void test17()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test18() TEST_CASE(test18)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -281,7 +281,7 @@ void test18() ...@@ -281,7 +281,7 @@ void test18()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test19() TEST_CASE(test19)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -295,7 +295,7 @@ void test19() ...@@ -295,7 +295,7 @@ void test19()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test20() TEST_CASE(test20)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
...@@ -309,7 +309,7 @@ void test20() ...@@ -309,7 +309,7 @@ void test20()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test21() TEST_CASE(test21)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
...@@ -323,7 +323,7 @@ void test21() ...@@ -323,7 +323,7 @@ void test21()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test22() TEST_CASE(test22)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
...@@ -337,7 +337,7 @@ void test22() ...@@ -337,7 +337,7 @@ void test22()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test23() TEST_CASE(test23)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -351,7 +351,7 @@ void test23() ...@@ -351,7 +351,7 @@ void test23()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test24() TEST_CASE(test24)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {32}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {32}});
...@@ -365,7 +365,7 @@ void test24() ...@@ -365,7 +365,7 @@ void test24()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test25() TEST_CASE(test25)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -379,7 +379,7 @@ void test25() ...@@ -379,7 +379,7 @@ void test25()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test26() TEST_CASE(test26)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -393,7 +393,7 @@ void test26() ...@@ -393,7 +393,7 @@ void test26()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test27() TEST_CASE(test27)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -405,7 +405,7 @@ void test27() ...@@ -405,7 +405,7 @@ void test27()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test28() TEST_CASE(test28)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
...@@ -419,7 +419,7 @@ void test28() ...@@ -419,7 +419,7 @@ void test28()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test29() TEST_CASE(test29)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
...@@ -434,7 +434,7 @@ void test29() ...@@ -434,7 +434,7 @@ void test29()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test30() TEST_CASE(test30)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("x", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("x", {migraph::shape::float_type, {8}});
...@@ -449,7 +449,7 @@ void test30() ...@@ -449,7 +449,7 @@ void test30()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test31() TEST_CASE(test31)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {8}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {8}});
...@@ -463,7 +463,7 @@ void test31() ...@@ -463,7 +463,7 @@ void test31()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test32() TEST_CASE(test32)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -477,7 +477,7 @@ void test32() ...@@ -477,7 +477,7 @@ void test32()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test33() TEST_CASE(test33)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {8}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {8}});
...@@ -491,7 +491,7 @@ void test33() ...@@ -491,7 +491,7 @@ void test33()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test34() TEST_CASE(test34)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
...@@ -505,7 +505,7 @@ void test34() ...@@ -505,7 +505,7 @@ void test34()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test35() TEST_CASE(test35)
{ {
migraph::program p; migraph::program p;
auto a1 = add_alloc(p, {migraph::shape::float_type, {40}}); auto a1 = add_alloc(p, {migraph::shape::float_type, {40}});
...@@ -519,7 +519,7 @@ void test35() ...@@ -519,7 +519,7 @@ void test35()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test36() TEST_CASE(test36)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {20}});
...@@ -536,7 +536,7 @@ void test36() ...@@ -536,7 +536,7 @@ void test36()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test37() TEST_CASE(test37)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {20}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {20}});
...@@ -553,7 +553,7 @@ void test37() ...@@ -553,7 +553,7 @@ void test37()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void test38() TEST_CASE(test38)
{ {
migraph::program p; migraph::program p;
auto output = p.add_parameter("output", {migraph::shape::float_type, {1, 64, 56, 56}}); auto output = p.add_parameter("output", {migraph::shape::float_type, {1, 64, 56, 56}});
...@@ -598,7 +598,7 @@ void test38() ...@@ -598,7 +598,7 @@ void test38()
CHECK(no_allocate(p)); CHECK(no_allocate(p));
} }
void literal_test() TEST_CASE(literal_test)
{ {
migraph::program p; migraph::program p;
auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}}); auto lit = generate_literal(migraph::shape{migraph::shape::float_type, {4, 3, 3, 3}});
...@@ -608,46 +608,4 @@ void literal_test() ...@@ -608,46 +608,4 @@ void literal_test()
CHECK(lit == result); CHECK(lit == result);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
test1();
test2();
test3();
test4();
test5();
test6();
test7();
test8();
test9();
test10();
test11();
test12();
test13();
test14();
test15();
test16();
test17();
test18();
test19();
test20();
test21();
test22();
test23();
test24();
test25();
test26();
test27();
test28();
test29();
test30();
test31();
test32();
test33();
test34();
test35();
test36();
test37();
test38();
literal_test();
}
...@@ -52,7 +52,7 @@ void throws_shape(const migraph::shape&, Ts...) ...@@ -52,7 +52,7 @@ void throws_shape(const migraph::shape&, Ts...)
"An expected shape should not be passed to throws_shape function"); "An expected shape should not be passed to throws_shape function");
} }
void batch_norm_inference_shape() TEST_CASE(batch_norm_inference_shape)
{ {
const size_t channels = 3; const size_t channels = 3;
migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}}; migraph::shape s{migraph::shape::float_type, {4, channels, 3, 3}};
...@@ -62,7 +62,7 @@ void batch_norm_inference_shape() ...@@ -62,7 +62,7 @@ void batch_norm_inference_shape()
throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars); throws_shape(migraph::op::batch_norm_inference{}, s, vars, vars, vars, vars, vars);
} }
void convolution_shape() TEST_CASE(convolution_shape)
{ {
migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}}; migraph::shape output{migraph::shape::float_type, {4, 4, 1, 1}};
migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}}; migraph::shape input{migraph::shape::float_type, {4, 3, 3, 3}};
...@@ -76,7 +76,7 @@ void convolution_shape() ...@@ -76,7 +76,7 @@ void convolution_shape()
throws_shape(migraph::op::convolution{}, input2, weights); throws_shape(migraph::op::convolution{}, input2, weights);
} }
void transpose_shape() TEST_CASE(transpose_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 2}}; migraph::shape input{migraph::shape::float_type, {2, 2}};
migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape output{migraph::shape::float_type, {2, 2}, {1, 2}};
...@@ -85,7 +85,7 @@ void transpose_shape() ...@@ -85,7 +85,7 @@ void transpose_shape()
throws_shape(migraph::op::transpose{{1, 2}}, input); throws_shape(migraph::op::transpose{{1, 2}}, input);
} }
void contiguous_shape() TEST_CASE(contiguous_shape)
{ {
migraph::shape output{migraph::shape::float_type, {2, 2}}; migraph::shape output{migraph::shape::float_type, {2, 2}};
migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape input{migraph::shape::float_type, {2, 2}, {1, 2}};
...@@ -96,7 +96,7 @@ void contiguous_shape() ...@@ -96,7 +96,7 @@ void contiguous_shape()
expect_shape(single, migraph::op::contiguous{}, single); expect_shape(single, migraph::op::contiguous{}, single);
} }
void reshape_shape() TEST_CASE(reshape_shape)
{ {
migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}}; migraph::shape input{migraph::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape : for(auto&& new_shape :
...@@ -114,7 +114,7 @@ void reshape_shape() ...@@ -114,7 +114,7 @@ void reshape_shape()
} }
} }
void flatten_shape() TEST_CASE(flatten_shape)
{ {
migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}}; migraph::shape input{migraph::shape::float_type, {2, 4, 6, 8}};
expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}}, expect_shape(migraph::shape{migraph::shape::float_type, {1, 2 * 4 * 6 * 8}},
...@@ -132,7 +132,7 @@ void flatten_shape() ...@@ -132,7 +132,7 @@ void flatten_shape()
throws_shape(migraph::op::flatten{5}, input); throws_shape(migraph::op::flatten{5}, input);
} }
void slice_shape() TEST_CASE(slice_shape)
{ {
migraph::shape input{migraph::shape::int32_type, {2, 2, 3}}; migraph::shape input{migraph::shape::int32_type, {2, 2, 3}};
expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}}, expect_shape(migraph::shape{migraph::shape::int32_type, {2, 2, 2}, {6, 3, 1}},
...@@ -145,13 +145,4 @@ void slice_shape() ...@@ -145,13 +145,4 @@ void slice_shape()
migraph::op::slice{{2}, {2}, {10}}, migraph::op::slice{{2}, {2}, {10}},
input); input);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
batch_norm_inference_shape();
convolution_shape();
transpose_shape();
contiguous_shape();
reshape_shape();
flatten_shape();
slice_shape();
}
...@@ -43,7 +43,7 @@ struct simple_operation_no_print ...@@ -43,7 +43,7 @@ struct simple_operation_no_print
} }
}; };
void operation_copy_test() TEST_CASE(operation_copy_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; // NOLINT migraph::operation op1 = s; // NOLINT
...@@ -54,7 +54,7 @@ void operation_copy_test() ...@@ -54,7 +54,7 @@ void operation_copy_test()
EXPECT(op2 == op1); EXPECT(op2 == op1);
} }
void operation_equal_test() TEST_CASE(operation_equal_test)
{ {
simple_operation s{}; simple_operation s{};
migraph::operation op1 = s; migraph::operation op1 = s;
...@@ -72,7 +72,7 @@ struct not_operation ...@@ -72,7 +72,7 @@ struct not_operation
{ {
}; };
void operation_any_cast() TEST_CASE(operation_any_cast)
{ {
migraph::operation op1 = simple_operation{}; migraph::operation op1 = simple_operation{};
EXPECT(migraph::any_cast<simple_operation>(op1).data == 1); EXPECT(migraph::any_cast<simple_operation>(op1).data == 1);
...@@ -83,7 +83,7 @@ void operation_any_cast() ...@@ -83,7 +83,7 @@ void operation_any_cast()
EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr); EXPECT(migraph::any_cast<not_operation*>(&op2) == nullptr);
} }
void operation_print() TEST_CASE(operation_print)
{ {
migraph::operation op = simple_operation{}; migraph::operation op = simple_operation{};
std::stringstream ss; std::stringstream ss;
...@@ -92,7 +92,7 @@ void operation_print() ...@@ -92,7 +92,7 @@ void operation_print()
EXPECT(s == "simple[1]"); EXPECT(s == "simple[1]");
} }
void operation_default_print() TEST_CASE(operation_default_print)
{ {
migraph::operation op = simple_operation_no_print{}; migraph::operation op = simple_operation_no_print{};
std::stringstream ss; std::stringstream ss;
...@@ -101,11 +101,4 @@ void operation_default_print() ...@@ -101,11 +101,4 @@ void operation_default_print()
EXPECT(s == "simple"); EXPECT(s == "simple");
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
operation_copy_test();
operation_equal_test();
operation_any_cast();
operation_print();
operation_default_print();
}
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <test.hpp> #include <test.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
void simple_alias() TEST_CASE(simple_alias)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
...@@ -12,7 +12,7 @@ void simple_alias() ...@@ -12,7 +12,7 @@ void simple_alias()
EXPECT(bool{migraph::instruction::get_output_alias(p1) == l}); EXPECT(bool{migraph::instruction::get_output_alias(p1) == l});
} }
void cascade_alias() TEST_CASE(cascade_alias)
{ {
migraph::program p; migraph::program p;
auto l = p.add_literal(1); auto l = p.add_literal(1);
...@@ -25,7 +25,7 @@ void cascade_alias() ...@@ -25,7 +25,7 @@ void cascade_alias()
EXPECT(bool{migraph::instruction::get_output_alias(p3) == l}); EXPECT(bool{migraph::instruction::get_output_alias(p3) == l});
} }
void no_alias() TEST_CASE(no_alias)
{ {
migraph::program p; migraph::program p;
auto x = p.add_literal(1); auto x = p.add_literal(1);
...@@ -34,9 +34,4 @@ void no_alias() ...@@ -34,9 +34,4 @@ void no_alias()
EXPECT(bool{migraph::instruction::get_output_alias(sum) == sum}); EXPECT(bool{migraph::instruction::get_output_alias(sum) == sum});
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
simple_alias();
cascade_alias();
no_alias();
}
...@@ -20,11 +20,11 @@ migraph::program create_program() ...@@ -20,11 +20,11 @@ migraph::program create_program()
return p; return p;
} }
void program_equality() TEST_CASE(program_equality)
{ {
migraph::program x = create_program(); migraph::program x = create_program();
migraph::program y = create_program(); migraph::program y = create_program();
EXPECT(x == y); EXPECT(x == y);
} }
int main() { program_equality(); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
#include <numeric> #include <numeric>
#include "test.hpp" #include "test.hpp"
void test_shape_default() TEST_CASE(test_shape_default)
{ {
migraph::shape s{}; migraph::shape s{};
EXPECT(s.elements() == 0); EXPECT(s.elements() == 0);
EXPECT(s.bytes() == 0); EXPECT(s.bytes() == 0);
} }
void test_shape_assign() TEST_CASE(test_shape_assign)
{ {
migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s1{migraph::shape::float_type, {100, 32, 8, 8}};
migraph::shape s2 = s1; // NOLINT migraph::shape s2 = s1; // NOLINT
...@@ -20,7 +20,7 @@ void test_shape_assign() ...@@ -20,7 +20,7 @@ void test_shape_assign()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape_packed_default() TEST_CASE(test_shape_packed_default)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -29,7 +29,7 @@ void test_shape_packed_default() ...@@ -29,7 +29,7 @@ void test_shape_packed_default()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_packed() TEST_CASE(test_shape_packed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {2, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -38,7 +38,7 @@ void test_shape_packed() ...@@ -38,7 +38,7 @@ void test_shape_packed()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_transposed() TEST_CASE(test_shape_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard()); EXPECT(not s.standard());
...@@ -47,7 +47,7 @@ void test_shape_transposed() ...@@ -47,7 +47,7 @@ void test_shape_transposed()
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
void test_shape_broadcasted() TEST_CASE(test_shape_broadcasted)
{ {
migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}}; migraph::shape s{migraph::shape::float_type, {2, 2}, {1, 0}};
EXPECT(not s.standard()); EXPECT(not s.standard());
...@@ -56,7 +56,7 @@ void test_shape_broadcasted() ...@@ -56,7 +56,7 @@ void test_shape_broadcasted()
EXPECT(s.broadcasted()); EXPECT(s.broadcasted());
} }
void test_shape_default_copy() TEST_CASE(test_shape_default_copy)
{ {
migraph::shape s1{}; migraph::shape s1{};
migraph::shape s2{}; migraph::shape s2{};
...@@ -64,7 +64,7 @@ void test_shape_default_copy() ...@@ -64,7 +64,7 @@ void test_shape_default_copy()
EXPECT(!(s1 != s2)); EXPECT(!(s1 != s2));
} }
void test_shape4() TEST_CASE(test_shape4)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -97,7 +97,7 @@ void test_shape4() ...@@ -97,7 +97,7 @@ void test_shape4()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape42() TEST_CASE(test_shape42)
{ {
migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}}; migraph::shape s{migraph::shape::float_type, {100, 32, 8, 8}, {2048, 64, 8, 1}};
EXPECT(s.standard()); EXPECT(s.standard());
...@@ -130,7 +130,7 @@ void test_shape42() ...@@ -130,7 +130,7 @@ void test_shape42()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape4_transposed() TEST_CASE(test_shape4_transposed)
{ {
migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}}; migraph::shape s{migraph::shape::float_type, {32, 100, 8, 8}, {64, 2048, 8, 1}};
EXPECT(s.transposed()); EXPECT(s.transposed());
...@@ -163,7 +163,7 @@ void test_shape4_transposed() ...@@ -163,7 +163,7 @@ void test_shape4_transposed()
EXPECT(s.index(s.elements() - 1) == s.elements() - 1); EXPECT(s.index(s.elements() - 1) == s.elements() - 1);
} }
void test_shape4_nonpacked() TEST_CASE(test_shape4_nonpacked)
{ {
std::vector<std::size_t> lens = {100, 32, 8, 8}; std::vector<std::size_t> lens = {100, 32, 8, 8};
std::array<std::size_t, 4> offsets = {{5, 10, 0, 6}}; std::array<std::size_t, 4> offsets = {{5, 10, 0, 6}};
...@@ -206,17 +206,4 @@ void test_shape4_nonpacked() ...@@ -206,17 +206,4 @@ void test_shape4_nonpacked()
EXPECT(s.index(s.elements() - 1) == 469273); EXPECT(s.index(s.elements() - 1) == 469273);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
test_shape_default();
test_shape_assign();
test_shape_packed_default();
test_shape_packed();
test_shape_transposed();
test_shape_broadcasted();
test_shape_default_copy();
test_shape4();
test_shape42();
test_shape4_transposed();
test_shape4_nonpacked();
}
...@@ -14,7 +14,7 @@ struct simplify_algebra_target ...@@ -14,7 +14,7 @@ struct simplify_algebra_target
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
void simplify_add1() TEST_CASE(simplify_add1)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -43,7 +43,7 @@ void simplify_add1() ...@@ -43,7 +43,7 @@ void simplify_add1()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void simplify_add2() TEST_CASE(simplify_add2)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -72,7 +72,7 @@ void simplify_add2() ...@@ -72,7 +72,7 @@ void simplify_add2()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
void simplify_add3() TEST_CASE(simplify_add3)
{ {
migraph::program p1; migraph::program p1;
{ {
...@@ -99,6 +99,7 @@ void simplify_add3() ...@@ -99,6 +99,7 @@ void simplify_add3()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
// TODO: Add test case
void simplify_add4() void simplify_add4()
{ {
migraph::program p1; migraph::program p1;
...@@ -128,10 +129,4 @@ void simplify_add4() ...@@ -128,10 +129,4 @@ void simplify_add4()
EXPECT(p1 == p2); EXPECT(p1 == p2);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
simplify_add1();
simplify_add2();
simplify_add3();
// simplify_add4();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment