Unverified Commit 3b284e9b authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #258 from ROCmSoftwarePlatform/parse_mean_fix

Parse mean fix
parents eeb5bad1 bae3b61b
...@@ -53,15 +53,16 @@ struct tf_parser ...@@ -53,15 +53,16 @@ struct tf_parser
template <class T> template <class T>
std::vector<T> parse_axes(std::vector<T> axes) const std::vector<T> parse_axes(std::vector<T> axes) const
{ {
std::vector<T> new_axes;
if(is_nhwc) if(is_nhwc)
{ {
std::vector<T> new_axes;
std::transform(axes.begin(), std::transform(axes.begin(),
axes.end(), axes.end(),
std::back_inserter(new_axes), std::back_inserter(new_axes),
[&](size_t axis) { return parse_axis(axis); }); [&](size_t axis) { return parse_axis(axis); });
return new_axes;
} }
return new_axes; return axes;
} }
// tf stores certain attributes such as strides, dilations, as a 4D input. // tf stores certain attributes such as strides, dilations, as a 4D input.
...@@ -426,17 +427,21 @@ struct tf_parser ...@@ -426,17 +427,21 @@ struct tf_parser
instruction_ref instruction_ref
parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_mean(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector()); auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; std::vector<int32_t> hw_axes{2, 3};
if(axes == hw_axes and keep_dims) // check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
if(axes == hw_axes and lens.size() == 4)
{ {
op::pooling op{"average"}; op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()}; op.lengths[0] = lens[2];
op.lengths[0] = input_dims[2]; op.lengths[1] = lens[3];
op.lengths[1] = input_dims[3]; auto l0 = prog.add_instruction(op, args.front());
return prog.add_instruction(op, args.front()); if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
} }
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation"); MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
......
...@@ -168,6 +168,40 @@ TEST_CASE(matmul_test) ...@@ -168,6 +168,40 @@ TEST_CASE(matmul_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(mean_test)
{
migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op;
op.lengths = {16, 16};
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("mean_test.pb", false);
EXPECT(p == prog);
}
TEST_CASE(mean_test_nhwc)
{
migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op;
op.lengths = {16, 16};
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
p.add_instruction(op, l0);
auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
}
TEST_CASE(mul_test) TEST_CASE(mul_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment