Commit 708c0401 authored by Khalique's avatar Khalique
Browse files

formatting

parent 0420b20d
...@@ -577,8 +577,8 @@ struct tf_parser ...@@ -577,8 +577,8 @@ struct tf_parser
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
// std::vector<int32_t> hw_axes{2, 3}; // std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met // check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens(); auto lens = args[0]->get_shape().lens();
auto axes = args[1]->eval().get<int32_t>().to_vector(); auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end()); std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
// if(axes == hw_axes and lens.size() == 4) // if(axes == hw_axes and lens.size() == 4)
...@@ -586,11 +586,10 @@ struct tf_parser ...@@ -586,11 +586,10 @@ struct tf_parser
// op::pooling op{"average"}; // op::pooling op{"average"};
// op.lengths[0] = lens[2]; // op.lengths[0] = lens[2];
// op.lengths[1] = lens[3]; // op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front()); auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
if(keep_dims) if(keep_dims)
return l0; return l0;
return prog.add_instruction( return prog.add_instruction(op::squeeze{axes_int64}, l0);
op::squeeze{axes_int64}, l0);
// } // }
// MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation"); // MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
} }
...@@ -770,7 +769,7 @@ struct tf_parser ...@@ -770,7 +769,7 @@ struct tf_parser
const attribute_map& attributes, const attribute_map& attributes,
std::vector<instruction_ref> args) std::vector<instruction_ref> args)
{ {
int axis = -1; int axis = -1;
auto num_dims = args[0]->get_shape().lens().size(); auto num_dims = args[0]->get_shape().lens().size();
if(contains(attributes, "axis")) if(contains(attributes, "axis"))
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment