Commit 6c27b962 authored by Khalique's avatar Khalique
Browse files

add capability to reduce dims

parent 767ca0cc
...@@ -372,13 +372,18 @@ struct tf_parser ...@@ -372,13 +372,18 @@ struct tf_parser
auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector()); auto axes = parse_axes(args[1]->eval().get<int32_t>().to_vector());
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; std::vector<int32_t> hw_axes{2, 3};
if(axes == hw_axes and keep_dims) // check if conditions for GlobalAvgPool are met
if(axes == hw_axes and args[0]->get_shape().lens().size() == 4)
{ {
op::pooling op{"average"}; op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()}; std::vector<size_t> input_dims{args[0]->get_shape().lens()};
op.lengths[0] = input_dims[2]; op.lengths[0] = input_dims[2];
op.lengths[1] = input_dims[3]; op.lengths[1] = input_dims[3];
return prog.add_instruction(op, args.front()); auto l0 = prog.add_instruction(op, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{std::vector<int64_t>(hw_axes.begin(), hw_axes.end())}, l0);
} }
MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation"); MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment