Commit a3a89d67 authored by Khalique's avatar Khalique
Browse files

added globalavgpool and globalmaxpool with tests

parent 3d57cfed
......@@ -182,7 +182,7 @@ struct pooling
std::string mode = "average";
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
std::vector<std::size_t> lengths = {{1, 1}};
template <class Self, class F>
static auto reflect(Self& self, F f)
......
......@@ -62,6 +62,8 @@ struct onnx_parser
add_mem_op("Conv", &onnx_parser::parse_conv);
add_mem_op("MaxPool", &onnx_parser::parse_pooling);
add_mem_op("AveragePool", &onnx_parser::parse_pooling);
add_mem_op("GlobalMaxPool", &onnx_parser::parse_pooling);
add_mem_op("GlobalAveragePool", &onnx_parser::parse_pooling);
add_mem_op("Reshape", &onnx_parser::parse_reshape);
add_mem_op("Flatten", &onnx_parser::parse_flatten);
add_mem_op("Gemm", &onnx_parser::parse_gemm);
......@@ -148,7 +150,15 @@ struct onnx_parser
attribute_map attributes,
std::vector<instruction_ref> args)
{
op::pooling op{name == "MaxPool" ? "max" : "average"};
op::pooling op{name == "MaxPool" or name == "GlobalMaxPool" ? "max" : "average"};
if(name == "GlobalMaxPool" or name == "GlobalAveragePool")
{
auto lens = args.front()->get_shape().lens();
auto num_lengths = lens.size() - 2; // ignore N and C values in lens
assert(num_lengths > 0);
op.lengths = std::vector<std::size_t>(num_lengths);
std::copy_n(lens.begin() + 2, num_lengths, op.lengths.begin());
}
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
......@@ -584,10 +594,15 @@ struct onnx_parser
}
std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[](auto&& d) { return d.dim_value(); });
std::transform(
tensor_dims.begin(), tensor_dims.end(), std::back_inserter(dims), [](auto&& d) {
if(not d.has_dim_value())
{
long default_batch_size = 1; // FIXME
return default_batch_size;
}
return d.dim_value();
});
return {shape_type, dims};
}
};
......
......@@ -160,6 +160,46 @@ void unsqueeze_test()
}
}
void globalavgpool_test()
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"average"};
auto lens = s.lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(op, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.25, 0.575, 0.375};
EXPECT(migraph::verify_range(results_vector, gold));
}
void globalmaxpool_test()
{
migraph::program p;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 2, 2}};
auto op = migraph::op::pooling{"max"};
auto lens = s.lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
std::vector<float> data{0.3, 0.2, 0.4, 0.1, 0.8, 0.5, 0.9, 0.1, 0.1, 0.7, 0.1, 0.6};
auto l0 = p.add_literal(migraph::literal{s, data});
p.add_instruction(op, l0);
p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({});
std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0.4, 0.9, 0.7};
EXPECT(migraph::verify_range(results_vector, gold));
}
void im2col_3x3_no_pad_identity_test()
{
std::size_t f[2] = {3, 3};
......@@ -1058,6 +1098,8 @@ int main()
conv2d_padding_test();
conv2d_padding_stride_test();
batch_norm_inference_test();
globalavgpool_test();
globalmaxpool_test();
im2col_3x3_no_pad_identity_test();
im2col_3x3_no_pad_test();
im2col_3x3_stride_2_no_pad_test();
......
......@@ -423,6 +423,36 @@ struct test_conv_pooling
}
};
struct test_global_avg_pooling
{
migraph::program create_program() const
{
migraph::program p;
auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
p.add_instruction(op, input);
return p;
}
};
struct test_global_max_pooling
{
migraph::program create_program() const
{
migraph::program p;
auto input =
p.add_parameter("x", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
p.add_instruction(op, input);
return p;
}
};
struct test_gemm
{
migraph::program create_program() const
......@@ -698,6 +728,8 @@ int main()
verify_program<test_add_relu>();
verify_program<test_leaky_relu>();
verify_program<test_conv_pooling>();
verify_program<test_global_avg_pooling>();
verify_program<test_global_max_pooling>();
verify_program<test_gemm>();
// verify_program<test_gemm_ld>();
verify_program<test_gemm_transposeb>();
......
globalavgpool-example:i

01"GlobalAveragePooltest-globalavgpoolZ
0




b
1




B
\ No newline at end of file
globalmaxpool-example:e

01" GlobalMaxPooltest-globalmaxpoolZ
0




b
1




B
\ No newline at end of file
......@@ -118,6 +118,34 @@ void imagescaler_test()
EXPECT(p == prog);
}
void globalavgpool_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"average"};
auto lens = input->get_shape().lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalavgpool_test.onnx");
EXPECT(p == prog);
}
void globalmaxpool_test()
{
migraph::program p;
auto input = p.add_parameter("0", migraph::shape{migraph::shape::float_type, {1, 3, 16, 16}});
auto op = migraph::op::pooling{"max"};
auto lens = input->get_shape().lens();
op.lengths = std::vector<std::size_t>{lens[2], lens[3]};
p.add_instruction(op, input);
auto prog = migraph::parse_onnx("globalmaxpool_test.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
......@@ -126,4 +154,6 @@ int main()
pytorch_conv_relu_maxpool_x2();
leaky_relu_test();
imagescaler_test();
globalavgpool_test();
globalmaxpool_test();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment