"testing/python/vscode:/vscode.git/clone" did not exist on "d4b6d0945e7a45db3883c13ed8d7049b568e0e94"
Commit 2f8e4e83 authored by Paul's avatar Paul
Browse files

Print ir from onnx file

parent 0fc287d3
...@@ -13,11 +13,11 @@ struct instruction ...@@ -13,11 +13,11 @@ struct instruction
instruction() {} instruction() {}
instruction(operand o, shape r, std::vector<instruction*> args) instruction(operand o, shape r, std::vector<instruction*> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit()
{} {}
instruction(literal l) instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l)) : op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l))
{} {}
operand op; operand op;
......
...@@ -6,11 +6,17 @@ ...@@ -6,11 +6,17 @@
#include <rtg/instruction.hpp> #include <rtg/instruction.hpp>
#include <rtg/operand.hpp> #include <rtg/operand.hpp>
#include <rtg/builtin.hpp> #include <rtg/builtin.hpp>
#include <algorithm>
namespace rtg { namespace rtg {
struct program struct program
{ {
// TODO: A program should be copyable
program() = default;
program(const program&) = delete;
program& operator=(const program&) = delete;
template<class... Ts> template<class... Ts>
instruction * add_instruction(operand op, Ts*... args) instruction * add_instruction(operand op, Ts*... args)
{ {
...@@ -18,6 +24,16 @@ struct program ...@@ -18,6 +24,16 @@ struct program
instructions.push_back({op, r, {args...}}); instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back()); return std::addressof(instructions.back());
} }
instruction * add_instruction(operand op, std::vector<instruction*> args)
{
assert(std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) && "Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args});
assert(instructions.back().arguments == args);
return std::addressof(instructions.back());
}
template<class... Ts> template<class... Ts>
instruction * add_literal(Ts&&... xs) instruction * add_literal(Ts&&... xs)
{ {
...@@ -36,6 +52,11 @@ struct program ...@@ -36,6 +52,11 @@ struct program
// TODO: Change to stream operator // TODO: Change to stream operator
void print() const; void print() const;
bool has_instruction(const instruction * ins) const
{
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {return ins == std::addressof(x); }) != instructions.end();
}
private: private:
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
......
...@@ -59,6 +59,7 @@ void program::print() const ...@@ -59,6 +59,7 @@ void program::print() const
char delim = '('; char delim = '(';
for(auto&& arg:ins.arguments) for(auto&& arg:ins.arguments)
{ {
assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg); std::cout << delim << names.at(arg);
delim = ','; delim = ',';
} }
...@@ -70,7 +71,7 @@ void program::print() const ...@@ -70,7 +71,7 @@ void program::print() const
std::cout << std::endl; std::cout << std::endl;
names.emplace(std::addressof(ins), var_name); names.emplace(std::addressof(ins), var_name);
count++;
} }
} }
......
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
struct unknown struct unknown
{ {
rtg::shape s;
std::string op; std::string op;
std::string name() const std::string name() const
{ {
...@@ -18,7 +17,8 @@ struct unknown ...@@ -18,7 +17,8 @@ struct unknown
} }
rtg::shape compute_shape(std::vector<rtg::shape> input) const rtg::shape compute_shape(std::vector<rtg::shape> input) const
{ {
return s; if(input.empty()) return {};
else return input.front();
} }
rtg::argument compute(std::vector<rtg::argument> input) const rtg::argument compute(std::vector<rtg::argument> input) const
{ {
...@@ -26,43 +26,88 @@ struct unknown ...@@ -26,43 +26,88 @@ struct unknown
} }
}; };
std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node) struct onnx_parser
{ {
std::unordered_map<std::string, onnx::AttributeProto> result; std::unordered_map<std::string, onnx::NodeProto> nodes;
for(auto&& attr:node.attribute()) std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
void parse_graph(const onnx::GraphProto& graph)
{ {
result[attr.name()] = attr; nodes = get_nodes(graph);
for(auto&& input:graph.input())
{
std::string name = input.name();
// TODO: Get shape of input parameter
instructions[name] = prog->add_parameter(name, rtg::shape{});
}
for(auto&& p:nodes)
{
this->parse_node(p.second.name());
}
} }
return result;
}
std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph) void parse_node(std::string name)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
{ {
result[node.name()] = node; if (instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args;
for(auto&& input:node.input())
{
if(nodes.count(input) > 0)
{
auto&& iname = nodes.at(input).name();
this->parse_node(iname);
args.push_back(instructions.at(iname));
}
else
{
args.push_back(instructions.at(input));
}
}
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
}
} }
return result;
}
void parse_graph(onnx::GraphProto graph) static std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
{ {
std::cout << "Graph name: " << graph.name() << std::endl; std::unordered_map<std::string, onnx::AttributeProto> result;
for(onnx::NodeProto node:graph.node()) {
std::cout << "Layer: " << node.op_type() << std::endl;
std::cout << " Name: " << node.name() << std::endl;
if(node.input_size() > 0)
std::cout << " Input: " << node.input(0) << std::endl;
if(node.output_size() > 0)
std::cout << " Output: " << node.output(0) << std::endl;
std::cout << " Attributes: " << std::endl;
for(auto&& attr:node.attribute()) for(auto&& attr:node.attribute())
{ {
std::cout << " " << attr.name() << std::endl; result[attr.name()] = attr;
}
return result;
}
static std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
{
result[node.name()] = node;
for(auto&& output:node.output())
{
result[output] = node;
}
}
return result;
}
};
std::shared_ptr<rtg::program> parse_onnx(std::istream& is)
{
onnx_parser parser;
onnx::ModelProto model;
if(model.ParseFromIstream(&is)) {
if(model.has_graph()) {
parser.parse_graph(model.graph());
} }
} else {
throw "Failed reading";
} }
return parser.prog;
} }
int main(int argc, char const *argv[]) int main(int argc, char const *argv[])
...@@ -71,18 +116,7 @@ int main(int argc, char const *argv[]) ...@@ -71,18 +116,7 @@ int main(int argc, char const *argv[])
{ {
std::string file = argv[1]; std::string file = argv[1];
std::fstream input(file.c_str(), std::ios::in | std::ios::binary); std::fstream input(file.c_str(), std::ios::in | std::ios::binary);
onnx::ModelProto model; auto prog = parse_onnx(input);
if(model.ParseFromIstream(&input)) { prog->print();
std::cout << "Model version: " << model.model_version() << std::endl;
std::cout << "Producer name: " << model.producer_name() << std::endl;
std::cout << "Producer version: " << model.release_producer_version() << std::endl;
if(model.has_graph()) {
std::cout << "Model has graph" << std::endl;
parse_graph(model.graph());
}
} else {
std::cout << "Failed reading: " << file << std::endl;
}
} }
} }
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace rtg { namespace rtg {
shape::shape() shape::shape()
: type_(float_type), lens_(), strides_() : type_(float_type), lens_(), strides_(), packed_(false)
{} {}
shape::shape(type_t t) shape::shape(type_t t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment