Commit 2ca12a73 authored by Khalique's avatar Khalique
Browse files

updated onnx tests

parent 5df4d2a8
......@@ -7,7 +7,7 @@
#include <migraphx/onnx.hpp>
#include "test.hpp"
void pytorch_conv_bias_test()
TEST_CASE(pytorch_conv_bias_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
......@@ -22,7 +22,7 @@ void pytorch_conv_bias_test()
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool()
TEST_CASE(pytorch_conv_relu_maxpool)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
......@@ -39,7 +39,7 @@ void pytorch_conv_relu_maxpool()
EXPECT(p == prog);
}
void pytorch_conv_bn_relu_maxpool()
TEST_CASE(pytorch_conv_bn_relu_maxpool)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
......@@ -62,7 +62,7 @@ void pytorch_conv_bn_relu_maxpool()
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool_x2()
TEST_CASE(pytorch_conv_relu_maxpool_x2)
{
migraphx::program p;
auto l0 = p.add_parameter("0", {migraphx::shape::float_type, {1, 3, 32, 32}});
......@@ -88,7 +88,7 @@ void pytorch_conv_relu_maxpool_x2()
EXPECT(p == prog);
}
void leaky_relu_test()
TEST_CASE(leaky_relu_test)
{
migraphx::program p;
float alpha = 0.01f;
......@@ -100,7 +100,7 @@ void leaky_relu_test()
EXPECT(p == prog);
}
void imagescaler_test()
TEST_CASE(imagescaler_test)
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3, 16, 16}};
......@@ -118,7 +118,7 @@ void imagescaler_test()
EXPECT(p == prog);
}
void globalavgpool_test()
TEST_CASE(globalavgpool_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
......@@ -132,7 +132,7 @@ void globalavgpool_test()
EXPECT(p == prog);
}
void globalmaxpool_test()
TEST_CASE(globalmaxpool_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
......@@ -146,7 +146,7 @@ void globalmaxpool_test()
EXPECT(p == prog);
}
void transpose_test()
TEST_CASE(transpose_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
......@@ -158,7 +158,7 @@ void transpose_test()
EXPECT(p == prog);
}
void dropout_test()
TEST_CASE(dropout_test)
{
migraphx::program p;
auto input = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 2, 2}});
......@@ -169,7 +169,7 @@ void dropout_test()
EXPECT(p == prog);
}
void sum_test()
TEST_CASE(sum_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
......@@ -181,7 +181,7 @@ void sum_test()
auto prog = migraphx::parse_onnx("sum_test.onnx");
}
void sin_test()
TEST_CASE(sin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -191,7 +191,7 @@ void sin_test()
EXPECT(p == prog);
}
void cos_test()
TEST_CASE(cos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -201,7 +201,7 @@ void cos_test()
EXPECT(p == prog);
}
void tan_test()
TEST_CASE(tan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -211,7 +211,7 @@ void tan_test()
EXPECT(p == prog);
}
void sinh_test()
TEST_CASE(sinh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -222,7 +222,7 @@ void sinh_test()
EXPECT(p == prog);
}
void cosh_test()
TEST_CASE(cosh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
......@@ -233,7 +233,7 @@ void cosh_test()
EXPECT(p == prog);
}
void tanh_test()
TEST_CASE(tanh_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1}});
......@@ -244,7 +244,7 @@ void tanh_test()
EXPECT(p == prog);
}
void asin_test()
TEST_CASE(asin_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -255,7 +255,7 @@ void asin_test()
EXPECT(p == prog);
}
void max_test()
TEST_CASE(max_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
......@@ -267,7 +267,7 @@ void max_test()
auto prog = migraphx::parse_onnx("max_test.onnx");
}
void acos_test()
TEST_CASE(acos_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -278,7 +278,7 @@ void acos_test()
EXPECT(p == prog);
}
void min_test()
TEST_CASE(min_test)
{
migraphx::program p;
auto input0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3}});
......@@ -290,7 +290,7 @@ void min_test()
auto prog = migraphx::parse_onnx("min_test.onnx");
}
void atan_test()
TEST_CASE(atan_test)
{
migraphx::program p;
auto input = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10}});
......@@ -301,7 +301,7 @@ void atan_test()
EXPECT(p == prog);
}
void add_bcast_test()
TEST_CASE(add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
......@@ -314,7 +314,7 @@ void add_bcast_test()
EXPECT(p == prog);
}
void implicit_bcast_test()
TEST_CASE(implicit_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
......@@ -328,29 +328,31 @@ void implicit_bcast_test()
EXPECT(p == prog);
}
void unknown_test()
TEST_CASE(unknown_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::unknown{"Unknown"});
auto l2 = p.add_instruction(migraphx::unknown{"Unknown"}, l0, l1);
p.add_instruction(migraphx::unknown{"Unknown"}, l2);
auto prog = migraphx::parse_onnx("unknown_test.onnx");
EXPECT(p == prog);
}
void softmax_test()
TEST_CASE(softmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{}, l0);
auto r = p.add_instruction(migraphx::op::reshape{{1, 3, 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{1, 3}}, s);
auto prog = migraphx::parse_onnx("softmax_test.onnx");
EXPECT(p == prog);
}
void reshape_test()
TEST_CASE(reshape_test)
{
migraphx::program p;
migraphx::op::reshape op;
......@@ -366,7 +368,7 @@ void reshape_test()
EXPECT(p == prog);
}
void flatten_test()
TEST_CASE(flatten_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
......@@ -377,7 +379,7 @@ void flatten_test()
EXPECT(p == prog);
}
void squeeze_unsqueeze_test()
TEST_CASE(squeeze_unsqueeze_test)
{
migraphx::program p;
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
......@@ -391,7 +393,7 @@ void squeeze_unsqueeze_test()
EXPECT(p == prog);
}
void concat_test()
TEST_CASE(concat_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
......@@ -402,7 +404,7 @@ void concat_test()
EXPECT(p == prog);
}
void slice_test()
TEST_CASE(slice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
......@@ -412,7 +414,7 @@ void slice_test()
EXPECT(p == prog);
}
void constant_test()
TEST_CASE(constant_test)
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
......@@ -421,7 +423,7 @@ void constant_test()
EXPECT(p == prog);
}
void gemm_test()
TEST_CASE(gemm_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 7}});
......@@ -437,7 +439,7 @@ void gemm_test()
EXPECT(p == prog);
}
void add_scalar_test()
TEST_CASE(add_scalar_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
......@@ -451,39 +453,5 @@ void add_scalar_test()
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_bn_relu_maxpool();
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
imagescaler_test();
globalavgpool_test();
globalmaxpool_test();
transpose_test();
dropout_test();
sum_test();
max_test();
min_test();
sin_test();
cos_test();
tan_test();
sinh_test();
cosh_test();
tanh_test();
asin_test();
acos_test();
atan_test();
add_bcast_test();
implicit_bcast_test();
unknown_test();
reshape_test();
flatten_test();
squeeze_unsqueeze_test();
concat_test();
slice_test();
constant_test();
gemm_test();
add_scalar_test();
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
unknown-example:|
unknown-example:

0
12"Unknown
"Unknown test-unknownZ
2"Unknown test-unknownZ
0


......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment