Unverified Commit acadd646 authored by Scott Thornton's avatar Scott Thornton Committed by GitHub
Browse files

Merge pull request #93 from ROCmSoftwarePlatform/onnx_parsing_squeeze_slice_concat

Onnx parsing squeeze slice concat
parents ad414ba9 1a0abc65
......@@ -66,6 +66,10 @@ struct onnx_parser
add_mem_op("Gemm", &onnx_parser::parse_gemm);
add_mem_op("BatchNormalization", &onnx_parser::parse_batchnorm);
add_mem_op("Softmax", &onnx_parser::parse_softmax);
add_mem_op("Squeeze", &onnx_parser::parse_squeeze);
add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
add_mem_op("Slice", &onnx_parser::parse_slice);
add_mem_op("Concat", &onnx_parser::parse_concat);
}
template <class F>
......@@ -187,6 +191,52 @@ struct onnx_parser
return prog.add_instruction(op::flatten{axis}, args[0]);
}
instruction_ref
parse_squeeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::squeeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_unsqueeze(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::unsqueeze op;
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
return prog.add_instruction(op, args[0]);
}
instruction_ref
parse_concat(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
std::size_t axis = parse_value(attributes.at("axis")).at<int>();
op::concat op{axis};
return prog.add_instruction(op, std::move(args));
}
instruction_ref
parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
op::slice op;
if(contains(attributes, "axes"))
{
literal s = parse_value(attributes.at("axes"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
}
{
literal s = parse_value(attributes.at("ends"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
}
{
literal s = parse_value(attributes.at("starts"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
}
return prog.add_instruction(op, args[0]);
}
instruction_ref parse_constant(const std::string&,
attribute_map attributes,
const std::vector<instruction_ref>&)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment