Commit ea166055 authored by Paul's avatar Paul
Browse files

Add initial support for multi output

parent ecbb4de5
...@@ -24,7 +24,7 @@ struct onnx_parser ...@@ -24,7 +24,7 @@ struct onnx_parser
{ {
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>; using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>; using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func = std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
...@@ -88,6 +88,15 @@ struct onnx_parser ...@@ -88,6 +88,15 @@ struct onnx_parser
template <class F> template <class F>
void add_op(std::string name, F f) void add_op(std::string name, F f)
{
ops.emplace(name, [=](auto&&... xs) {
return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
});
}
// Multi output op
template <class F>
void add_multi_op(std::string name, F f)
{ {
ops.emplace(name, f); ops.emplace(name, f);
} }
...@@ -95,7 +104,7 @@ struct onnx_parser ...@@ -95,7 +104,7 @@ struct onnx_parser
template <class F> template <class F>
void add_mem_op(std::string name, F f) void add_mem_op(std::string name, F f)
{ {
ops.emplace(name, [=](auto&&... xs) { add_op(name, [=](auto&&... xs) {
return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...); return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
}); });
} }
...@@ -103,7 +112,7 @@ struct onnx_parser ...@@ -103,7 +112,7 @@ struct onnx_parser
template <class T> template <class T>
void add_binary_op(std::string name, T x) void add_binary_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPHX_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast")) if(contains(attributes, "broadcast"))
...@@ -174,7 +183,7 @@ struct onnx_parser ...@@ -174,7 +183,7 @@ struct onnx_parser
template <class T> template <class T>
void add_generic_op(std::string name, T x) void add_generic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return prog.add_instruction(x, args); return prog.add_instruction(x, args);
}); });
} }
...@@ -182,7 +191,7 @@ struct onnx_parser ...@@ -182,7 +191,7 @@ struct onnx_parser
template <class T> template <class T>
void add_variadic_op(std::string name, T x) void add_variadic_op(std::string name, T x)
{ {
ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) { add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
return std::accumulate(std::next(args.begin()), return std::accumulate(std::next(args.begin()),
args.end(), args.end(),
args.front(), args.front(),
...@@ -645,7 +654,7 @@ struct onnx_parser ...@@ -645,7 +654,7 @@ struct onnx_parser
} }
else else
{ {
throw std::runtime_error("Failed reading"); MIGRAPHX_THROW("Failed reading onnx file.");
} }
} }
...@@ -691,24 +700,28 @@ struct onnx_parser ...@@ -691,24 +700,28 @@ struct onnx_parser
{ {
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
auto&& iname = get_name(nodes.at(input)); // auto&& iname = get_name(nodes.at(input));
assert(name != iname); assert(name != input);
this->parse_node(iname); this->parse_node(input);
args.push_back(instructions.at(iname)); args.push_back(instructions.at(input));
} }
else else
{ {
args.push_back(instructions.at(input)); args.push_back(instructions.at(input));
} }
} }
std::vector<instruction_ref> result;
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
instructions[name] = prog.add_instruction(unknown{node.op_type()}, args); result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
} }
else else
{ {
instructions[name] = ops[node.op_type()](get_attributes(node), args); result = ops[node.op_type()](get_attributes(node), args);
} }
std::transform(node.output().begin(), node.output().end(), result.begin(), std::inserter(instructions, instructions.end()), [](auto&& onnx_out, auto&& parse_out) {
return std::make_pair(onnx_out, parse_out);
});
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment