Commit c5c9f667 authored by Khalique's avatar Khalique
Browse files

initial implementation of stridedslice

parent dc85aa6b
...@@ -123,6 +123,7 @@ struct tf_parser ...@@ -123,6 +123,7 @@ struct tf_parser
add_mem_op("Reshape", &tf_parser::parse_reshape); add_mem_op("Reshape", &tf_parser::parse_reshape);
add_mem_op("Softmax", &tf_parser::parse_softmax); add_mem_op("Softmax", &tf_parser::parse_softmax);
add_mem_op("Squeeze", &tf_parser::parse_squeeze); add_mem_op("Squeeze", &tf_parser::parse_squeeze);
add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
} }
template <class F> template <class F>
...@@ -480,6 +481,32 @@ struct tf_parser ...@@ -480,6 +481,32 @@ struct tf_parser
return prog.add_instruction(op, args[0]); return prog.add_instruction(op, args[0]);
} }
instruction_ref parse_stridedslice(const std::string&, const attribute_map& attributes, std::vector<instruction_ref> args)
{
op::slice op;
auto begin = args[1]->eval().get<int64_t>().to_vector();;
auto end = args[2]->eval().get<int64_t>().to_vector();;
op.starts = begin;
op.ends = end;
int shrink_axis_mask = 0;
std::vector<int64_t> squeeze_axes;
if(contains(attributes, "shrink_axis_mask"))
shrink_axis_mask = attributes.at("shrink_axis_mask").i();
size_t num_axes = args[0]->get_shape().lens().size();
for (size_t i = 0; i < num_axes; i++)
{
if ((shrink_axis_mask >> i) & 1)
squeeze_axes.push_back(i);
}
auto l0 = prog.add_instruction(op, args[0]);
return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
}
void parse_graph(const tensorflow::GraphDef& graph) void parse_graph(const tensorflow::GraphDef& graph)
{ {
nodes = get_nodes(graph, input_nodes); nodes = get_nodes(graph, input_nodes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment