Commit 523eabeb authored by Khalique's avatar Khalique
Browse files

added fixes to reading strided slice

parent e2b588c0
...@@ -219,8 +219,9 @@ void simplify_reshapes::apply(program& p) const ...@@ -219,8 +219,9 @@ void simplify_reshapes::apply(program& p) const
ins, ins,
find_nop_reshapes{}, find_nop_reshapes{},
find_reshaper{}, find_reshaper{},
find_transpose{}, find_transpose{}
find_concat_transpose{}); // find_concat_transpose{}
);
} }
} }
......
...@@ -174,7 +174,7 @@ struct tf_parser ...@@ -174,7 +174,7 @@ struct tf_parser
add_mem_op("Reshape", &tf_parser::parse_reshape, false); add_mem_op("Reshape", &tf_parser::parse_reshape, false);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze, false); add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice); add_mem_op("StridedSlice", &tf_parser::parse_stridedslice, false);
} }
template <class F> template <class F>
...@@ -717,7 +717,9 @@ struct tf_parser ...@@ -717,7 +717,9 @@ struct tf_parser
op::slice op; op::slice op;
auto starts = args[1]->eval().get<int32_t>().to_vector(); auto starts = args[1]->eval().get<int32_t>().to_vector();
auto ends = args[2]->eval().get<int32_t>().to_vector(); auto ends = args[2]->eval().get<int32_t>().to_vector();
size_t num_axes = args[0]->get_shape().lens().size(); auto l0 = args[0];
size_t num_axes = l0->get_shape().lens().size();
std::vector<size_t> axes = l0->get_shape().lens();
op.starts = std::vector<int64_t>(starts.begin(), starts.end()); op.starts = std::vector<int64_t>(starts.begin(), starts.end());
op.ends = std::vector<int64_t>(ends.begin(), ends.end()); op.ends = std::vector<int64_t>(ends.begin(), ends.end());
...@@ -758,12 +760,6 @@ struct tf_parser ...@@ -758,12 +760,6 @@ struct tf_parser
end_axes.push_back(0); end_axes.push_back(0);
} }
if(num_axes >= 4)
{
reorder_data(begin_axes);
reorder_data(end_axes);
}
for(size_t i = 0; i < num_axes; i++) for(size_t i = 0; i < num_axes; i++)
{ {
if(begin_axes.at(i) == 1) if(begin_axes.at(i) == 1)
...@@ -776,9 +772,9 @@ struct tf_parser ...@@ -776,9 +772,9 @@ struct tf_parser
} }
} }
auto l0 = prog.add_instruction(op, args[0]); auto l1 = prog.add_instruction(op, l0);
if(shrink_axis_mask == 0) if(shrink_axis_mask == 0)
return l0; return l1;
for(size_t i = 0; i < num_axes; i++) for(size_t i = 0; i < num_axes; i++)
{ {
...@@ -787,8 +783,7 @@ struct tf_parser ...@@ -787,8 +783,7 @@ struct tf_parser
squeeze_axes.push_back(i); squeeze_axes.push_back(i);
} }
auto l0 = prog.add_instruction(op, make_contiguous(args[0])); return prog.add_instruction(op::squeeze{squeeze_axes}, l1);
return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
} }
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment