Commit 59b0d9e6 authored by Khalique's avatar Khalique
Browse files

added test for group_conv

parent 09ed8ee2
...@@ -56,7 +56,7 @@ struct miopen_apply ...@@ -56,7 +56,7 @@ struct miopen_apply
void check_shape(shape x, instruction_ref i) void check_shape(shape x, instruction_ref i)
{ {
assert(x.lens() == i->get_shape().lens()); assert(x == i->get_shape());
(void)x; (void)x;
(void)i; (void)i;
} }
......
...@@ -534,6 +534,22 @@ struct test_conv2 ...@@ -534,6 +534,22 @@ struct test_conv2
} }
}; };
struct test_group_conv
{
migraphx::program create_program() const
{
migraphx::program p;
auto input =
p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto weights =
p.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, input, weights);
return p;
}
};
struct test_conv_relu struct test_conv_relu
{ {
migraphx::program create_program() const migraphx::program create_program() const
...@@ -1034,6 +1050,7 @@ int main() ...@@ -1034,6 +1050,7 @@ int main()
verify_program<test_softmax2>(); verify_program<test_softmax2>();
verify_program<test_conv>(); verify_program<test_conv>();
verify_program<test_conv2>(); verify_program<test_conv2>();
verify_program<test_group_conv>();
verify_program<test_conv_relu>(); verify_program<test_conv_relu>();
verify_program<test_conv_relu_half>(); verify_program<test_conv_relu_half>();
verify_program<test_add_relu>(); verify_program<test_add_relu>();
......
...@@ -484,4 +484,15 @@ TEST_CASE(add_scalar_test) ...@@ -484,4 +484,15 @@ TEST_CASE(add_scalar_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(group_conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 4, 16, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 1, 3, 3}});
migraphx::op::convolution op;
op.group = 4;
p.add_instruction(op, l0, l1);
auto prog = migraphx::parse_onnx("group_conv_test.onnx");
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment