Commit b639892a authored by Paul's avatar Paul
Browse files

Add another test for conv bn

parent d2778c9e
...@@ -439,6 +439,40 @@ struct test_conv_bn_relu_pooling ...@@ -439,6 +439,40 @@ struct test_conv_bn_relu_pooling
} }
}; };
struct test_conv_bn_relu_pooling2
{
static migraph::instruction_ref add_bn(migraph::program& p, migraph::instruction_ref x, std::size_t channels)
{
migraph::shape vars{migraph::shape::float_type, {channels}};
auto scale = p.add_literal(migraph::abs(migraph::generate_literal(vars, 1+channels)));
auto bias = p.add_literal(migraph::abs(migraph::generate_literal(vars, 2+channels)));
auto mean = p.add_literal(migraph::abs(migraph::generate_literal(vars, 3+channels)));
auto variance = p.add_literal(migraph::abs(migraph::generate_literal(vars, 4+channels)));
return p.add_instruction(migraph::batch_norm_inference{}, x, scale, bias, mean, variance);
}
migraph::program create_program() const
{
migraph::program p;
migraph::shape xs1{migraph::shape::float_type, {1, 512, 7, 7}};
migraph::shape xs2{migraph::shape::float_type, {1, 1024, 14, 14}};
migraph::shape ws1{migraph::shape::float_type, {2048, 512, 1, 1}};
migraph::shape ws2{migraph::shape::float_type, {2048, 1024, 1, 1}};
auto x1 = p.add_parameter("x1", xs1);
auto w1 = p.add_parameter("w1", ws1);
auto conv1 = p.add_instruction(migraph::convolution{{0, 0}, {1, 1}, {1, 1}}, x1, w1);
auto bn1 = add_bn(p, conv1, 2048);
auto x2 = p.add_parameter("x2", xs2);
auto w2 = p.add_parameter("w2", ws2);
auto conv2 = p.add_instruction(migraph::convolution{{0, 0}, {2, 2}, {1, 1}}, x2, w2);
auto bn2 = add_bn(p, conv2, 2048);
auto add = p.add_instruction(migraph::add{}, bn1, bn2);
auto relu = p.add_instruction(migraph::activation{"relu"}, add);
p.add_instruction(migraph::pooling{"average", {1, 1}, {2, 2}, {3, 3}}, relu);
return p;
}
};
int main() int main()
{ {
verify_program<test_add>(); verify_program<test_add>();
...@@ -460,4 +494,5 @@ int main() ...@@ -460,4 +494,5 @@ int main()
verify_program<test_batchnorm_inference>(); verify_program<test_batchnorm_inference>();
verify_program<test_batchnorm_inference_2>(); verify_program<test_batchnorm_inference_2>();
verify_program<test_conv_bn_relu_pooling>(); verify_program<test_conv_bn_relu_pooling>();
verify_program<test_conv_bn_relu_pooling2>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment