Unverified Commit d895104a authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Fix onnx mean parsing for integral inputs (#1209)

As described in #1196, the ONNX mean parser does not work correctly for integral types. This update fixes the issue by handling integral types separately, where summation is performed before division. Additional test cases have also been added for handling integral types.
parent 4a312201
......@@ -2,6 +2,7 @@
#include <migraphx/onnx/checks.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -9,6 +10,9 @@ namespace onnx {
struct parse_mean : op_parser<parse_mean>
{
const std::set<shape::type_t> float_types = {
shape::float_type, shape::half_type, shape::double_type};
std::vector<op_desc> operators() const { return {{"Mean"}}; }
/// Calculates the element-wise mean of n>=1 input tensors
......@@ -24,7 +28,8 @@ struct parse_mean : op_parser<parse_mean>
auto divisor = info.add_literal(
migraphx::literal{migraphx::shape{args[0]->get_shape().type()}, {num_data}});
// TODO: Only divide when using floating-point
if(contains(float_types, args[0]->get_shape().type()))
{
return std::accumulate(args.begin() + 1,
args.end(),
info.add_broadcastable_binary_op("div", args[0], divisor),
......@@ -36,6 +41,17 @@ struct parse_mean : op_parser<parse_mean>
return info.add_broadcastable_binary_op("add", mean, div);
});
}
else
{
// Compute sum before division for integral types
auto sum = std::accumulate(
args.begin() + 1, args.end(), args[0], [&](auto accum, auto data_i) {
return info.add_broadcastable_binary_op("add", accum, data_i);
});
return info.add_broadcastable_binary_op("div", sum, divisor);
}
}
};
} // namespace onnx
......
......@@ -3178,6 +3178,20 @@ def mean_test():
return ([node], data, [mean])
@onnx_test
def mean_integral_test():
data = [
helper.make_tensor_value_info(str(i), TensorProto.INT32, [2, 2, 2])
for i in range(10)
]
data_names = [str(i) for i in range(10)]
mean = helper.make_tensor_value_info('mean', TensorProto.INT32, [2, 2, 2])
node = onnx.helper.make_node("Mean", inputs=data_names, outputs=["mean"])
return ([node], data, [mean])
@onnx_test
def min_test():
a = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
mean_integral_test:Ö
*
0
1
2
3
4
5
6
7
8
9mean"Meanmean_integral_testZ
0



Z
1



Z
2



Z
3



Z
4



Z
5



Z
6



Z
7



Z
8



Z
9



b
mean



B
\ No newline at end of file
......@@ -2890,6 +2890,30 @@ TEST_CASE(mean_test)
EXPECT(p == prog);
}
TEST_CASE(mean_integral_test)
{
const std::size_t num_data = 10;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}};
auto mean = mm->add_parameter("0", s);
for(std::size_t i = 1; i < num_data; ++i)
{
auto data = mm->add_parameter(std::to_string(i), s);
mean = mm->add_instruction(migraphx::make_op("add"), mean, data);
}
auto div_lit = mm->add_literal(migraphx::literal{migraphx::shape{s.type()}, {num_data}});
auto divisor =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), div_lit);
mean = mm->add_instruction(migraphx::make_op("div"), mean, divisor);
auto prog = optimize_onnx("mean_integral_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(min_test)
{
migraphx::program p;
......
......@@ -581,6 +581,33 @@ TEST_CASE(mean_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(mean_integral_test)
{
migraphx::program p = migraphx::parse_onnx("mean_integral_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s{migraphx::shape::int32_type, {2, 2, 2}};
const int num_elms = 8;
const int num_data = 10;
const std::vector<int> scalars{1, 5, 14, 2, 6, 21, 101, 0, -4, -11};
std::vector<std::vector<int>> data;
std::transform(scalars.begin(), scalars.end(), std::back_inserter(data), [&](const auto i) {
return std::vector<int>(num_elms, i);
});
migraphx::parameter_map pp;
for(std::size_t i = 0; i < num_data; ++i)
pp[std::to_string(i)] = migraphx::argument(s, data[i].data());
auto result = p.eval(pp).back();
std::vector<double> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
const auto mean = std::accumulate(scalars.begin(), scalars.end(), 0) / num_data;
std::vector<int> gold(num_elms, mean);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(nonzero_test)
{
migraphx::program p = migraphx::parse_onnx("nonzero_dynamic_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment