Commit 08b3f215 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add onnx test for reduce_mean

parent 1644d4fc
......@@ -96,6 +96,7 @@ struct onnx_parser
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
add_mem_op("ReduceMean", &onnx_parser::parse_reduce_mean);
// init the activation function map
init_actv_func();
......@@ -1321,6 +1322,39 @@ struct onnx_parser
}
}
instruction_ref parse_reduce_mean(const std::string&,
attribute_map attributes,
std::vector<instruction_ref> args)
{
std::size_t n_dim = args.front()->get_shape().lens().size();
// default to reduce over all dimensions
std::vector<int64_t> axes(n_dim);
std::iota(axes.begin(), axes.end(), 0);
if(contains(attributes, "axes"))
{
axes.clear();
auto&& attr_axes = attributes["axes"].ints();
axes = std::vector<int64_t>(attr_axes.begin(), attr_axes.end());
}
int keep_dims = 1;
if(contains(attributes, "keepdims"))
{
keep_dims = parse_value(attributes.at("keepdims")).at<int>();
}
if(keep_dims == 1)
{
return prog.add_instruction(op::reduce_mean{axes}, std::move(args));
}
else
{
auto ins = prog.add_instruction(op::reduce_mean{axes}, std::move(args));
return prog.add_instruction(op::squeeze{axes}, ins);
}
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -863,7 +863,7 @@ TEST_CASE(reducemean_test2)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::reduce_mean{{2}}, l0);
auto prog = migraphx::parse_onnx("reducemean_test2.onnx");
EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment