"vscode:/vscode.git/clone" did not exist on "24a6fec2d45a835d4ea85eb0486ada208503b490"
Commit 71c777bd authored by Scott Thornton's avatar Scott Thornton
Browse files

Changes for being able to read ONNX from PyTorch. Still issue with Reshape

parent 8ae3ffea
......@@ -11,6 +11,7 @@
#include <migraph/program.hpp>
#include <migraph/operators.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
namespace migraph {
......@@ -63,6 +64,13 @@ struct onnx_parser
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(args.size() == 3)
{
uint64_t axis = 1;
auto l1 = prog.add_instruction(op, args[0], args[1]);
auto l2 = prog.add_instruction(broadcast{axis}, l1, args[2]);
return prog.add_instruction(add{}, l1, l2);
}
return prog.add_instruction(op, args);
});
add_op("MatMul", [this](attribute_map, std::vector<instruction_ref> args) {
......@@ -90,9 +98,17 @@ struct onnx_parser
});
add_op("Reshape", [this](attribute_map attributes, std::vector<instruction_ref> args) {
reshape op;
if(args.size() == 1)
{
literal s = parse_value(attributes.at("shape"));
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
return prog.add_instruction(op, args);
}
if(args.size() == 2)
{
literal s = args[1]->lit;
s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
}
return prog.add_instruction(op, args[0]);
});
add_op("Constant", [this](attribute_map attributes, std::vector<instruction_ref>) {
literal v = parse_value(attributes.at("value"));
......@@ -194,7 +210,7 @@ struct onnx_parser
}
for(auto&& p : nodes)
{
this->parse_node(p.second.name());
this->parse_node(get_name(p.second));
}
}
......@@ -210,7 +226,7 @@ struct onnx_parser
{
if(nodes.count(input) > 0)
{
auto&& iname = nodes.at(input).name();
auto&& iname = get_name(nodes.at(input));
assert(name != iname);
this->parse_node(iname);
args.push_back(instructions.at(iname));
......@@ -241,12 +257,26 @@ struct onnx_parser
return result;
}
static std::string get_name(const onnx::NodeProto& node)
{
if(node.name().empty())
{
std::string generated = "migraph_unnamed_node";
for(auto&& output : node.output())
{
generated += "_" + output;
}
return generated;
}
return node.name();
}
static node_map get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node : graph.node())
{
result[node.name()] = node;
result[get_name(node)] = node;
for(auto&& output : node.output())
{
result[output] = node;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment