Commit 13c9ed9c authored by Khalique's avatar Khalique
Browse files

add flatten, squeeze/unsqueeze, concat, slice, constant onnx tests

parent ddeeedb9
......@@ -316,7 +316,7 @@ struct onnx_parser
instruction_ref
parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
uint64_t axis = 0;
uint64_t axis = 1;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
......
flatten-example:„

02"Flatten*
axis 

03"Flatten test-flattenZ
0




b
2


b
3


<B
\ No newline at end of file
......@@ -366,6 +366,60 @@ void reshape_test()
EXPECT(p == prog);
}
void flatten_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
p.add_instruction(migraphx::op::flatten{1}, l0);
p.add_instruction(migraphx::op::flatten{2}, l0);
auto prog = migraphx::parse_onnx("flatten_test.onnx");
EXPECT(p == prog);
}
void squeeze_unsqueeze_test()
{
migraphx::program p;
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 1, 1, 2, 1}});
auto l1 = p.add_instruction(migraphx::op::squeeze{squeeze_axes}, l0);
p.add_instruction(migraphx::op::unsqueeze{unsqueeze_axes}, l1);
auto prog = migraphx::parse_onnx("squeeze_unsqueeze_test.onnx");
EXPECT(p == prog);
}
void concat_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {7, 4, 3}});
p.add_instruction(migraphx::op::concat{0}, l0, l1);
auto prog = migraphx::parse_onnx("concat_test.onnx");
EXPECT(p == prog);
}
void slice_test()
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 2}});
p.add_instruction(migraphx::op::slice{{0,1}, {1,0}, {2,2}}, l0);
auto prog = migraphx::parse_onnx("slice_test.onnx");
EXPECT(p == prog);
}
void constant_test()
{
migraphx::program p;
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type, {3}}, {0, 1, 2}});
auto prog = migraphx::parse_onnx("constant_test.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
......@@ -394,4 +448,9 @@ int main()
implicit_bcast_test();
unknown_test();
reshape_test();
flatten_test();
squeeze_unsqueeze_test();
concat_test();
slice_test();
constant_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment