Commit a5bebb21 authored by Khalique's avatar Khalique
Browse files

add lens variable to simplify code

parent 31436830
...@@ -431,12 +431,12 @@ struct tf_parser ...@@ -431,12 +431,12 @@ struct tf_parser
bool keep_dims = attributes.at("keep_dims").b(); bool keep_dims = attributes.at("keep_dims").b();
std::vector<int32_t> hw_axes{2, 3}; std::vector<int32_t> hw_axes{2, 3};
// check if conditions for GlobalAvgPool are met // check if conditions for GlobalAvgPool are met
if(axes == hw_axes and args[0]->get_shape().lens().size() == 4) auto lens = args[0]->get_shape().lens();
if(axes == hw_axes and lens.size() == 4)
{ {
op::pooling op{"average"}; op::pooling op{"average"};
std::vector<size_t> input_dims{args[0]->get_shape().lens()}; op.lengths[0] = lens[2];
op.lengths[0] = input_dims[2]; op.lengths[1] = lens[3];
op.lengths[1] = input_dims[3];
auto l0 = prog.add_instruction(op, args.front()); auto l0 = prog.add_instruction(op, args.front());
if(keep_dims) if(keep_dims)
return l0; return l0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment