Commit 2ca12a73 authored by Khalique's avatar Khalique
Browse files

updated onnx tests

parent 5df4d2a8
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <migraphx/onnx.hpp> #include <migraphx/onnx.hpp>
#include "test.hpp" #include "test.hpp"
void pytorch_conv_bias_test() TEST_CASE(pytorch_conv_bias_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
...@@ -22,7 +22,7 @@ void pytorch_conv_bias_test() ...@@ -22,7 +22,7 @@ void pytorch_conv_bias_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_relu_maxpool() TEST_CASE(pytorch_conv_relu_maxpool)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
...@@ -39,7 +39,7 @@ void pytorch_conv_relu_maxpool() ...@@ -39,7 +39,7 @@ void pytorch_conv_relu_maxpool()
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_bn_relu_maxpool() TEST_CASE(pytorch_conv_bn_relu_maxpool)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
...@@ -62,7 +62,7 @@ void pytorch_conv_bn_relu_maxpool() ...@@ -62,7 +62,7 @@ void pytorch_conv_bn_relu_maxpool()
EXPECT(p == prog); EXPECT(p == prog);
} }
void pytorch_conv_relu_maxpool_x2() TEST_CASE(pytorch_conv_relu_maxpool_x2)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}}); auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
...@@ -88,7 +88,7 @@ void pytorch_conv_relu_maxpool_x2() ...@@ -88,7 +88,7 @@ void pytorch_conv_relu_maxpool_x2()
EXPECT(p == prog); EXPECT(p == prog);
} }
void leaky_relu_test() TEST_CASE(leaky_relu_test)
{ {
migraphx::program p; migraphx::program p;
float alpha = 0.01f; float alpha = 0.01f;
...@@ -100,7 +100,7 @@ void leaky_relu_test() ...@@ -100,7 +100,7 @@ void leaky_relu_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void imagescaler_test() TEST_CASE(imagescaler_test)
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}}; migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
...@@ -118,7 +118,7 @@ void imagescaler_test() ...@@ -118,7 +118,7 @@ void imagescaler_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void globalavgpool_test() TEST_CASE(globalavgpool_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
...@@ -132,7 +132,7 @@ void globalavgpool_test() ...@@ -132,7 +132,7 @@ void globalavgpool_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void globalmaxpool_test() TEST_CASE(globalmaxpool_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
...@@ -146,7 +146,7 @@ void globalmaxpool_test() ...@@ -146,7 +146,7 @@ void globalmaxpool_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void transpose_test() TEST_CASE(transpose_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
...@@ -158,7 +158,7 @@ void transpose_test() ...@@ -158,7 +158,7 @@ void transpose_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void dropout_test() TEST_CASE(dropout_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}}); auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
...@@ -169,7 +169,7 @@ void dropout_test() ...@@ -169,7 +169,7 @@ void dropout_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void sum_test() TEST_CASE(sum_test)
{ {
migraphx::program p; migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
...@@ -181,7 +181,7 @@ void sum_test() ...@@ -181,7 +181,7 @@ void sum_test()
auto prog = migraphx::parse_onnx("sum_test.onnx"); auto prog = migraphx::parse_onnx("sum_test.onnx");
} }
void sin_test() TEST_CASE(sin_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -191,7 +191,7 @@ void sin_test() ...@@ -191,7 +191,7 @@ void sin_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void cos_test() TEST_CASE(cos_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -201,7 +201,7 @@ void cos_test() ...@@ -201,7 +201,7 @@ void cos_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void tan_test() TEST_CASE(tan_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -211,7 +211,7 @@ void tan_test() ...@@ -211,7 +211,7 @@ void tan_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void sinh_test() TEST_CASE(sinh_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -222,7 +222,7 @@ void sinh_test() ...@@ -222,7 +222,7 @@ void sinh_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void cosh_test() TEST_CASE(cosh_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
...@@ -233,7 +233,7 @@ void cosh_test() ...@@ -233,7 +233,7 @@ void cosh_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void tanh_test() TEST_CASE(tanh_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
...@@ -244,7 +244,7 @@ void tanh_test() ...@@ -244,7 +244,7 @@ void tanh_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void asin_test() TEST_CASE(asin_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -255,7 +255,7 @@ void asin_test() ...@@ -255,7 +255,7 @@ void asin_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void max_test() TEST_CASE(max_test)
{ {
migraphx::program p; migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
...@@ -267,7 +267,7 @@ void max_test() ...@@ -267,7 +267,7 @@ void max_test()
auto prog = migraphx::parse_onnx("max_test.onnx"); auto prog = migraphx::parse_onnx("max_test.onnx");
} }
void acos_test() TEST_CASE(acos_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -278,7 +278,7 @@ void acos_test() ...@@ -278,7 +278,7 @@ void acos_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void min_test() TEST_CASE(min_test)
{ {
migraphx::program p; migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}}); auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
...@@ -290,7 +290,7 @@ void min_test() ...@@ -290,7 +290,7 @@ void min_test()
auto prog = migraphx::parse_onnx("min_test.onnx"); auto prog = migraphx::parse_onnx("min_test.onnx");
} }
void atan_test() TEST_CASE(atan_test)
{ {
migraphx::program p; migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}}); auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
...@@ -301,7 +301,7 @@ void atan_test() ...@@ -301,7 +301,7 @@ void atan_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void add_bcast_test() TEST_CASE(add_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
...@@ -314,7 +314,7 @@ void add_bcast_test() ...@@ -314,7 +314,7 @@ void add_bcast_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void implicit_bcast_test() TEST_CASE(implicit_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
...@@ -328,29 +328,31 @@ void implicit_bcast_test() ...@@ -328,29 +328,31 @@ void implicit_bcast_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void unknown_test() TEST_CASE(unknown_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1); auto l2 = p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::unknown{"Unknown"}); p.add_instruction(migraphx::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx"); auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void softmax_test() TEST_CASE(softmax_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{}, l0); auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
auto prog = migraphx::parse_onnx("softmax_test.onnx"); auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
void reshape_test() TEST_CASE(reshape_test)
{ {
migraphx::program p; migraphx::program p;
migraphx::op::reshape op; migraphx::op::reshape op;
...@@ -366,7 +368,7 @@ void reshape_test() ...@@ -366,7 +368,7 @@ void reshape_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void flatten_test() TEST_CASE(flatten_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
...@@ -377,7 +379,7 @@ void flatten_test() ...@@ -377,7 +379,7 @@ void flatten_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void squeeze_unsqueeze_test() TEST_CASE(squeeze_unsqueeze_test)
{ {
migraphx::program p; migraphx::program p;
std::vector<int64_t> squeeze_axes{0, 2, 3, 5}; std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
...@@ -391,7 +393,7 @@ void squeeze_unsqueeze_test() ...@@ -391,7 +393,7 @@ void squeeze_unsqueeze_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void concat_test() TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
...@@ -402,7 +404,7 @@ void concat_test() ...@@ -402,7 +404,7 @@ void concat_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void slice_test() TEST_CASE(slice_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
...@@ -412,7 +414,7 @@ void slice_test() ...@@ -412,7 +414,7 @@ void slice_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void constant_test() TEST_CASE(constant_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}}); p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
...@@ -421,7 +423,7 @@ void constant_test() ...@@ -421,7 +423,7 @@ void constant_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void gemm_test() TEST_CASE(gemm_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
...@@ -437,7 +439,7 @@ void gemm_test() ...@@ -437,7 +439,7 @@ void gemm_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void add_scalar_test() TEST_CASE(add_scalar_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
...@@ -451,39 +453,5 @@ void add_scalar_test() ...@@ -451,39 +453,5 @@ void add_scalar_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
int main() int main(int argc, const char* argv[]) { test::run(argc, argv); }
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
imagescaler_test();
globalavgpool_test();
globalmaxpool_test();
transpose_test();
dropout_test();
sum_test();
max_test();
min_test();
sin_test();
cos_test();
tan_test();
sinh_test();
cosh_test();
tanh_test();
asin_test();
acos_test();
atan_test();
add_bcast_test();
implicit_bcast_test();
unknown_test();
reshape_test();
flatten_test();
squeeze_unsqueeze_test();
concat_test();
slice_test();
constant_test();
gemm_test();
add_scalar_test();
}
unknown-example:| unknown-example:
 
0 0
12"Unknown 12"Unknown
"Unknown test-unknownZ
2"Unknown test-unknownZ
0 0
 
 
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment