Commit 1ea07f5a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add onnx support for argmax and argmin

parent 7ccdb25e
...@@ -63,6 +63,8 @@ struct onnx_parser ...@@ -63,6 +63,8 @@ struct onnx_parser
add_variadic_op("Max", op::max{}); add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{}); add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Clip", &onnx_parser::parse_clip); add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn); add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler); add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
...@@ -266,6 +268,33 @@ struct onnx_parser ...@@ -266,6 +268,33 @@ struct onnx_parser
return prog.add_instruction(op::logsoftmax{axis}, std::move(args)); return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
} }
instruction_ref parse_argmax(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::argmax{axis}, std::move(args));
}
instruction_ref parse_argmin(const std::string&,
const attribute_map& attributes,
std::vector<instruction_ref> args)
{
int axis = 0;
if(contains(attributes, "axis"))
{
axis = parse_value(attributes.at("axis")).at<int>();
}
return prog.add_instruction(op::argmin{axis}, std::move(args));
}
instruction_ref instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args) parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment