Commit 71c777bd authored by Scott Thornton's avatar Scott Thornton
Browse files

Changes for being able to read ONNX from PyTorch. Still issue with Reshape

parent 8ae3ffea
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/operators.hpp> #include <migraph/operators.hpp>
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
namespace migraph { namespace migraph {
...@@ -63,6 +64,13 @@ struct onnx_parser ...@@ -63,6 +64,13 @@ struct onnx_parser
{ {
copy(attributes["dilations"].ints(), op.dilation.begin()); copy(attributes["dilations"].ints(), op.dilation.begin());
} }
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(broadcast{axis}, l1, args[2]);
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args); return prog.add_instruction(op, args);
}); });
add_op("MatMul", [this](attribute_map, std::vector<instruction_ref> args) { add_op("MatMul", [this](attribute_map, std::vector<instruction_ref> args) {
...@@ -90,9 +98,17 @@ struct onnx_parser ...@@ -90,9 +98,17 @@ struct onnx_parser
}); });
add_op("Reshape", [this](attribute_map attributes, std::vector<instruction_ref> args) { add_op("Reshape", [this](attribute_map attributes, std::vector<instruction_ref> args) {
reshape op; reshape op;
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape")); literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args); }
if(args.size() == 2)
{
literal s = args[1]->lit;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
}); });
add_op("Constant", [this](attribute_map attributes, std::vector<instruction_ref>) { add_op("Constant", [this](attribute_map attributes, std::vector<instruction_ref>) {
literal v = parse_value(attributes.at("value")); literal v = parse_value(attributes.at("value"));
...@@ -194,7 +210,7 @@ struct onnx_parser ...@@ -194,7 +210,7 @@ struct onnx_parser
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
this->parse_node(p.second.name()); this->parse_node(get_name(p.second));
} }
} }
...@@ -210,7 +226,7 @@ struct onnx_parser ...@@ -210,7 +226,7 @@ struct onnx_parser
{ {
if(nodes.count(input) > 0) if(nodes.count(input) > 0)
{ {
auto&& iname = nodes.at(input).name(); auto&& iname = get_name(nodes.at(input));
assert(name != iname); assert(name != iname);
this->parse_node(iname); this->parse_node(iname);
args.push_back(instructions.at(iname)); args.push_back(instructions.at(iname));
...@@ -241,12 +257,26 @@ struct onnx_parser ...@@ -241,12 +257,26 @@ struct onnx_parser
return result; return result;
} }
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraph_unnamed_node";
for(auto&& output : node.output())
{
generated += "_" + output;
}
return generated;
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph) static node_map get_nodes(const onnx::GraphProto& graph)
{ {
std::unordered_map<std::string, onnx::NodeProto> result; std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node : graph.node()) for(auto&& node : graph.node())
{ {
result[node.name()] = node; result[get_name(node)] = node;
for(auto&& output : node.output()) for(auto&& output : node.output())
{ {
result[output] = node; result[output] = node;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment