Unverified Commit b7218806 authored by turneram, committed by GitHub

Add Mean op ONNX parser (#1065)

* Add mean op onnx parser and unit tests
* Refactor parse_mean to use add_broadcastable_binary_op
parent 332cb710
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <numeric>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
auto num_data = args.size();
if(num_data == 1)
return args[0];
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
// Accumulate the pre-divided inputs. The seed value args[0] is only a
// placeholder: data_i aliases the current element of args, so on the first
// iteration the division result is written back into args[0], the comparison
// below still sees the two as equal, and the divided first input replaces
// the seed as the running mean.
return std::accumulate(args.begin(), args.end(), args[0], [&](auto& mean, auto& data_i) {
// Pre-divide each tensor element-wise by n to reduce the risk of overflow
// during summation (most important for low-precision types such as fp16).
data_i = info.add_broadcastable_binary_op("div", data_i, divisor);
if(data_i != args[0])
return info.add_broadcastable_binary_op("add", mean, data_i);
return data_i;
});
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
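Why the parser divides each input before summing: for low-precision types such as fp16, summing n large inputs first can overflow even when the true mean is representable. A minimal numpy sketch of the emitted decomposition mean(x_0, ..., x_{n-1}) = sum_i(x_i / n), illustrative only and not part of the commit:

import numpy as np

def decomposed_mean(*tensors):
    # Mirror the parser: divide every input by n, then accumulate with adds.
    n = len(tensors)
    if n == 1:
        return tensors[0]
    acc = tensors[0] / n
    for t in tensors[1:]:
        acc = acc + t / n
    return acc

a = np.full((1, 2, 3), 60000.0, dtype=np.float16)  # near the fp16 max of 65504
b = np.full((1, 2, 3), 60000.0, dtype=np.float16)
print(decomposed_mean(a, b)[0, 0, 0])  # 60000.0: every intermediate stays in range
print(((a + b) / 2)[0, 0, 0])          # inf: the naive sum overflows before dividing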
@@ -2762,6 +2762,80 @@ def maxpool_same_upper_test():
return ([node], [x], [y])
@onnx_test
def mean_broadcast_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 4])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT,
[1, 2, 3, 4])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [4])
data_3 = helper.make_tensor_value_info('3', TensorProto.FLOAT, [1])
data_4 = helper.make_tensor_value_info('4', TensorProto.FLOAT, [2, 3, 1])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT,
[1, 2, 3, 4])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2", "3", "4"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2, data_3, data_4], [mean])
@onnx_test
def mean_fp16_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT16, [1, 2, 3])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT16, [1, 2, 3])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT16, [1, 2, 3])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT16,
[1, 2, 3])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2], [mean])
@onnx_test
def mean_invalid_broadcast_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3])
data_1 = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 2, 3])
data_2 = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 2, 4])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3])
node = onnx.helper.make_node("Mean",
inputs=["0", "1", "2"],
outputs=["mean"])
return ([node], [data_0, data_1, data_2], [mean])
@onnx_test
def mean_single_input_test():
data_0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 2, 3])
mean = helper.make_tensor_value_info('mean', TensorProto.FLOAT, [1, 2, 3])
node = onnx.helper.make_node("Mean", inputs=["0"], outputs=["mean"])
return ([node], [data_0], [mean])
@onnx_test
def mean_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.DOUBLE, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.DOUBLE, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
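For reference, a hedged sketch of what the @onnx_test decorator plausibly does with the (nodes, inputs, outputs) tuple each generator returns; the actual helper lives in this repo's test generation script and may differ in detail:

import onnx
from onnx import helper

def save_onnx_test(fn):
    # Assemble a graph named after the generator and serialize it to disk,
    # producing the .onnx artifacts exercised by the C++ tests below.
    nodes, inputs, outputs = fn()
    graph = helper.make_graph(nodes, fn.__name__, inputs, outputs)
    onnx.save(helper.make_model(graph), fn.__name__ + '.onnx')

save_onnx_test(mean_test)  # writes mean_test.onnx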
@onnx_test
def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
...
[Binary files added: mean_broadcast_test.onnx, mean_fp16_test.onnx, mean_invalid_broadcast_test.onnx, mean_single_input_test.onnx, mean_test.onnx. These are the serialized ONNX protobufs generated from the test cases above; binary content not rendered.]
@@ -2492,6 +2492,50 @@ TEST_CASE(maxpool_same_upper_test)
EXPECT(p == prog);
}
TEST_CASE(mean_invalid_broadcast_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("mean_invalid_broadcast_test.onnx"); }));
}
TEST_CASE(mean_single_input_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3}});
mm->add_return({data0});
auto prog = migraphx::parse_onnx("mean_single_input_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(mean_fp16_test)
{
const std::size_t num_data = 3;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::half_type, {1, 2, 3}};
auto data0 = mm->add_parameter("0", s);
auto data1 = mm->add_parameter("1", s);
auto data2 = mm->add_parameter("2", s);
auto div_lit = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {num_data}});
auto divisor =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit);
auto mean = mm->add_instruction(migraphx::make_op("div"), data0, divisor);
divisor =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit);
data1 = mm->add_instruction(migraphx::make_op("div"), data1, divisor);
mean = mm->add_instruction(migraphx::make_op("add"), mean, data1);
divisor =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit);
data2 = mm->add_instruction(migraphx::make_op("div"), data2, divisor);
mean = mm->add_instruction(migraphx::make_op("add"), mean, data2);
auto prog = optimize_onnx("mean_fp16_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(min_test)
{
migraphx::program p;
...
@@ -393,6 +393,64 @@ TEST_CASE(lessorequal_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(mean_broadcast_test)
{
migraphx::program p = migraphx::parse_onnx("mean_broadcast_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s0{migraphx::shape::float_type, {1, 3, 4}};
std::vector<float> data0(12, 1);
migraphx::shape s1{migraphx::shape::float_type, {1, 2, 3, 4}};
std::vector<float> data1(24, 2);
migraphx::shape s2{migraphx::shape::float_type, {4}};
std::vector<float> data2(4, 3);
migraphx::shape s3{migraphx::shape::float_type, {1}};
std::vector<float> data3(1, 4);
migraphx::shape s4{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> data4(6, 5);
migraphx::parameter_map pp;
pp["0"] = migraphx::argument(s0, data0.data());
pp["1"] = migraphx::argument(s1, data1.data());
pp["2"] = migraphx::argument(s2, data2.data());
pp["3"] = migraphx::argument(s3, data3.data());
pp["4"] = migraphx::argument(s4, data4.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(24, 3);
EXPECT(migraphx::verify_range(result_vector, gold));
}
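As a sanity check on the gold value: all five inputs broadcast to {1, 2, 3, 4} under ONNX multidirectional broadcasting, and the constants 1 through 5 average to 3. The same computation in plain numpy, illustrative only:

import numpy as np

xs = [np.full((1, 3, 4), 1.0), np.full((1, 2, 3, 4), 2.0),
      np.full((4,), 3.0), np.full((1,), 4.0), np.full((2, 3, 1), 5.0)]
mean = sum(x / len(xs) for x in xs)  # numpy broadcasting matches ONNX here
assert mean.shape == (1, 2, 3, 4)
assert np.allclose(mean, 3.0)  # agrees with the test's gold(24, 3)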
TEST_CASE(mean_test)
{
migraphx::program p = migraphx::parse_onnx("mean_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::double_type, {2, 2, 2}};
const int num_elms = 8;
const int num_data = 10;
const std::vector<double> scalars{1.0, 2.0, -2.5, 3.3, 10.7, -1.0, 100.0, 7.9, 0.01, -56.8};
std::vector<std::vector<double>> data;
std::transform(scalars.begin(), scalars.end(), std::back_inserter(data), [&](const auto& i) {
return std::vector<double>(num_elms, i);
});
migraphx::parameter_map pp;
for(std::size_t i = 0; i < num_data; ++i)
pp[std::to_string(i)] = migraphx::argument(s, data[i].data());
auto result = p.eval(pp).back();
std::vector<double> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0.0) / num_data;
std::vector<double> gold(num_elms, mean);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(nonzero_test)
{
migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx");
...
@@ -269,9 +269,6 @@ def create_backend_test(testname=None, target_device=None):
backend_test.exclude(r'test_gathernd_example_int32_cpu')
backend_test.exclude(r'test_identity_sequence_cpu')
backend_test.exclude(r'test_maxpool_2d_uint8_cpu')
backend_test.exclude(r'test_mean_example_cpu')
backend_test.exclude(r'test_mean_one_input_cpu')
backend_test.exclude(r'test_mean_two_inputs_cpu')
backend_test.exclude(r'test_negative_log_likelihood_loss_*')
backend_test.exclude(r'test_scatternd_*')
...