Commit 708c0401 authored by Khalique's avatar Khalique
Browse files

formatting

parent 0420b20d
......@@ -577,8 +577,8 @@ struct tf_parser
bool keep_dims = attributes.at("keep_dims").b();
// std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met
auto lens = args[0]->get_shape().lens();
auto axes = args[1]->eval().get<int32_t>().to_vector();
auto lens = args[0]->get_shape().lens();
auto axes = args[1]->eval().get<int32_t>().to_vector();
std::vector<int64_t> axes_int64 = std::vector<int64_t>(axes.begin(), axes.end());
// if(axes == hw_axes and lens.size() == 4)
......@@ -586,11 +586,10 @@ struct tf_parser
// op::pooling op{"average"};
// op.lengths[0] = lens[2];
// op.lengths[1] = lens[3];
auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
auto l0 = prog.add_instruction(op::reduce_mean{axes_int64}, args.front());
if(keep_dims)
return l0;
return prog.add_instruction(
op::squeeze{axes_int64}, l0);
return prog.add_instruction(op::squeeze{axes_int64}, l0);
// }
// MIGRAPHX_THROW("MIGraphX does not support mean outside of GlobalAvgPool transformation");
}
......@@ -770,7 +769,7 @@ struct tf_parser
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = -1;
int axis = -1;
auto num_dims = args[0]->get_shape().lens().size();
if(contains(attributes, "axis"))
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment