Commit 1af6c18a authored by Scott Thornton's avatar Scott Thornton
Browse files

Added parsing for squeeze, unsqueeze, slice, and concat

parent 11d00d61
...@@ -65,6 +65,10 @@ struct onnx_parser ...@@ -65,6 +65,10 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm); add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm); add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax); add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
//add_mem_op("Concat", &onnx_parser::parse_concat);
} }
template <class F> template <class F>
...@@ -186,6 +190,52 @@ struct onnx_parser ...@@ -186,6 +190,52 @@ struct onnx_parser
return prog.add_instruction(op::flatten{axis}, args[0]); return prog.add_instruction(op::flatten{axis}, args[0]);
} }
instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
// instruction_ref
// parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
// {
// literal axis = parse_value(attributes.at("axis")).at<int>();
// op::concat op{axis};
// return prog.add_instruction(op, std::move(args));
// }
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::slice op;
if(contains(attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
{
literal s = parse_value(attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
}
{
literal s = parse_value(attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_constant(const std::string&, instruction_ref parse_constant(const std::string&,
attribute_map attributes, attribute_map attributes,
const std::vector<instruction_ref>&) const std::vector<instruction_ref>&)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment