Unverified Commit 9ca0fbf1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #100 from ROCmSoftwarePlatform/globalavgpool

added globalavgpool and globalmaxpool with tests
parents 0d0778b7 ea05b277
...@@ -62,6 +62,8 @@ struct onnx_parser ...@@ -62,6 +62,8 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv); add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling); add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling); add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape); add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten); add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
...@@ -148,7 +150,12 @@ struct onnx_parser ...@@ -148,7 +150,12 @@ struct onnx_parser
attribute_map attributes, attribute_map attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
op::pooling op{name == "MaxPool" ? "max" : "average"}; op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
if(starts_with(name, "Global"))
{
auto lens = args.front()->get_shape().lens();
op.lengths = {lens[2], lens[3]};
}
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
copy(attributes["pads"].ints(), op.padding.begin()); copy(attributes["pads"].ints(), op.padding.begin());
...@@ -584,10 +591,15 @@ struct onnx_parser ...@@ -584,10 +591,15 @@ struct onnx_parser
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(), std::transform(
tensor_dims.end(), tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) {
std::back_inserter(dims), if(not d.has_dim_value())
[](auto&& d) { return d.dim_value(); }); {
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
return {shape_type, dims}; return {shape_type, dims};
} }
}; };
......
...@@ -160,6 +160,46 @@ void unsqueeze_test() ...@@ -160,6 +160,46 @@ void unsqueeze_test()
} }
} }
void globalavgpool_test()
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"average"};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(op, l0);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraph::verify_range(results_vector, gold));
}
void globalmaxpool_test()
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"max"};
auto lens = s.lens();
op.lengths = {lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(op, l0);
p.compile(migraph::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraph::verify_range(results_vector, gold));
}
void im2col_3x3_no_pad_identity_test() void im2col_3x3_no_pad_identity_test()
{ {
std::size_t f[2] = {3, 3}; std::size_t f[2] = {3, 3};
...@@ -1058,6 +1098,8 @@ int main() ...@@ -1058,6 +1098,8 @@ int main()
conv2d_padding_test(); conv2d_padding_test();
conv2d_padding_stride_test(); conv2d_padding_stride_test();
batch_norm_inference_test(); batch_norm_inference_test();
globalavgpool_test();
globalmaxpool_test();
im2col_3x3_no_pad_identity_test(); im2col_3x3_no_pad_identity_test();
im2col_3x3_no_pad_test(); im2col_3x3_no_pad_test();
im2col_3x3_stride_2_no_pad_test(); im2col_3x3_stride_2_no_pad_test();
......
...@@ -451,6 +451,36 @@ struct test_conv_pooling ...@@ -451,6 +451,36 @@ struct test_conv_pooling
} }
}; };
struct test_global_avg_pooling
{
migraph::program create_program() const
{
migraph::program p;
auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
return p;
}
};
struct test_global_max_pooling
{
migraph::program create_program() const
{
migraph::program p;
auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
return p;
}
};
struct test_gemm struct test_gemm
{ {
migraph::program create_program() const migraph::program create_program() const
...@@ -728,6 +758,8 @@ int main() ...@@ -728,6 +758,8 @@ int main()
verify_program<test_add_relu>(); verify_program<test_add_relu>();
verify_program<test_leaky_relu>(); verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>(); verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>(); verify_program<test_gemm>();
// verify_program<test_gemm_ld>(); // verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>(); verify_program<test_gemm_transposeb>();
......
globalavgpool-example:i

01"GlobalAveragePooltest-globalavgpoolZ
0




b
1




B
\ No newline at end of file
globalmaxpool-example:e

01" GlobalMaxPooltest-globalmaxpoolZ
0




b
1




B
\ No newline at end of file
...@@ -118,6 +118,34 @@ void imagescaler_test() ...@@ -118,6 +118,34 @@ void imagescaler_test()
EXPECT(p == prog); EXPECT(p == prog);
} }
void globalavgpool_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalavgpool_test.onnx");
EXPECT(p == prog);
}
void globalmaxpool_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = {lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog);
}
int main() int main()
{ {
pytorch_conv_bias_test(); pytorch_conv_bias_test();
...@@ -126,4 +154,6 @@ int main() ...@@ -126,4 +154,6 @@ int main()
pytorch_conv_relu_maxpool_x2(); pytorch_conv_relu_maxpool_x2();
leaky_relu_test(); leaky_relu_test();
imagescaler_test(); imagescaler_test();
globalavgpool_test();
globalmaxpool_test();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment