Commit 71c777bd authored by Scott Thornton's avatar Scott Thornton
Browse files

Changes for being able to read ONNX from PyTorch. Still issue with Reshape

parent 8ae3ffea
......@@ -11,6 +11,7 @@
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
namespace migraph {
......@@ -63,6 +64,13 @@ struct onnx_parser
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(broadcast{axis}, l1, args[2]);
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args);
});
add_op("MatMul", [this](attribute_map, std::vector<instruction_ref> args) {
......@@ -90,9 +98,17 @@ struct onnx_parser
});
add_op("Reshape", [this](attribute_map attributes, std::vector<instruction_ref> args) {
reshape op;
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args);
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
if(args.size() == 2)
{
literal s = args[1]->lit;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
});
add_op("Constant", [this](attribute_map attributes, std::vector<instruction_ref>) {
literal v = parse_value(attributes.at("value"));
......@@ -194,7 +210,7 @@ struct onnx_parser
}
for(auto&& p : nodes)
{
this->parse_node(p.second.name());
this->parse_node(get_name(p.second));
}
}
......@@ -210,7 +226,7 @@ struct onnx_parser
{
if(nodes.count(input) > 0)
{
auto&& iname = nodes.at(input).name();
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
......@@ -241,12 +257,26 @@ struct onnx_parser
return result;
}
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraph_unnamed_node";
for(auto&& output : node.output())
{
generated += "_" + output;
}
return generated;
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node : graph.node())
{
result[node.name()] = node;
result[get_name(node)] = node;
for(auto&& output : node.output())
{
result[output] = node;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment