Commit 2f8e4e83 authored by Paul's avatar Paul
Browse files

Print ir from onnx file

parent 0fc287d3
......@@ -13,11 +13,11 @@ struct instruction
instruction() {}
instruction(operand o, shape r, std::vector<instruction*> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)), lit()
{}
instruction(literal l)
: op(builtin::literal{}), result(l.get_shape()), lit(std::move(l))
: op(builtin::literal{}), result(l.get_shape()), arguments(), lit(std::move(l))
{}
operand op;
......
......@@ -6,11 +6,17 @@
#include <rtg/instruction.hpp>
#include <rtg/operand.hpp>
#include <rtg/builtin.hpp>
#include <algorithm>
namespace rtg {
struct program
{
// TODO: A program should be copyable
program() = default;
program(const program&) = delete;
program& operator=(const program&) = delete;
template<class... Ts>
instruction * add_instruction(operand op, Ts*... args)
{
......@@ -18,6 +24,16 @@ struct program
instructions.push_back({op, r, {args...}});
return std::addressof(instructions.back());
}
instruction * add_instruction(operand op, std::vector<instruction*> args)
{
assert(std::all_of(args.begin(), args.end(), [&](instruction* x) { return has_instruction(x); }) && "Argument is not an exisiting instruction");
std::vector<shape> shapes(args.size());
std::transform(args.begin(), args.end(), shapes.begin(), [](instruction* ins) { return ins->result; });
shape r = op.compute_shape(shapes);
instructions.push_back({op, r, args});
assert(instructions.back().arguments == args);
return std::addressof(instructions.back());
}
template<class... Ts>
instruction * add_literal(Ts&&... xs)
{
......@@ -36,6 +52,11 @@ struct program
// TODO: Change to stream operator
void print() const;
bool has_instruction(const instruction * ins) const
{
return std::find_if(instructions.begin(), instructions.end(), [&](const instruction& x) {return ins == std::addressof(x); }) != instructions.end();
}
private:
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
......
......@@ -59,6 +59,7 @@ void program::print() const
char delim = '(';
for(auto&& arg:ins.arguments)
{
assert(this->has_instruction(arg) && "Instruction not found");
std::cout << delim << names.at(arg);
delim = ',';
}
......@@ -70,7 +71,7 @@ void program::print() const
std::cout << std::endl;
names.emplace(std::addressof(ins), var_name);
count++;
}
}
......
......@@ -10,7 +10,6 @@
struct unknown
{
rtg::shape s;
std::string op;
std::string name() const
{
......@@ -18,7 +17,8 @@ struct unknown
}
rtg::shape compute_shape(std::vector<rtg::shape> input) const
{
return s;
if(input.empty()) return {};
else return input.front();
}
rtg::argument compute(std::vector<rtg::argument> input) const
{
......@@ -26,43 +26,88 @@ struct unknown
}
};
std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
struct onnx_parser
{
std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr:node.attribute())
std::unordered_map<std::string, onnx::NodeProto> nodes;
std::unordered_map<std::string, rtg::instruction*> instructions;
std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();
void parse_graph(const onnx::GraphProto& graph)
{
result[attr.name()] = attr;
nodes = get_nodes(graph);
for(auto&& input:graph.input())
{
std::string name = input.name();
// TODO: Get shape of input parameter
instructions[name] = prog->add_parameter(name, rtg::shape{});
}
for(auto&& p:nodes)
{
this->parse_node(p.second.name());
}
}
return result;
}
std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
void parse_node(std::string name)
{
result[node.name()] = node;
if (instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
std::vector<rtg::instruction*> args;
for(auto&& input:node.input())
{
if(nodes.count(input) > 0)
{
auto&& iname = nodes.at(input).name();
this->parse_node(iname);
args.push_back(instructions.at(iname));
}
else
{
args.push_back(instructions.at(input));
}
}
instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
}
}
return result;
}
void parse_graph(onnx::GraphProto graph)
{
std::cout << "Graph name: " << graph.name() << std::endl;
for(onnx::NodeProto node:graph.node()) {
std::cout << "Layer: " << node.op_type() << std::endl;
std::cout << " Name: " << node.name() << std::endl;
if(node.input_size() > 0)
std::cout << " Input: " << node.input(0) << std::endl;
if(node.output_size() > 0)
std::cout << " Output: " << node.output(0) << std::endl;
std::cout << " Attributes: " << std::endl;
static std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
{
std::unordered_map<std::string, onnx::AttributeProto> result;
for(auto&& attr:node.attribute())
{
std::cout << " " << attr.name() << std::endl;
result[attr.name()] = attr;
}
return result;
}
static std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
{
std::unordered_map<std::string, onnx::NodeProto> result;
for(auto&& node:graph.node())
{
result[node.name()] = node;
for(auto&& output:node.output())
{
result[output] = node;
}
}
return result;
}
};
std::shared_ptr<rtg::program> parse_onnx(std::istream& is)
{
onnx_parser parser;
onnx::ModelProto model;
if(model.ParseFromIstream(&is)) {
if(model.has_graph()) {
parser.parse_graph(model.graph());
}
} else {
throw "Failed reading";
}
return parser.prog;
}
int main(int argc, char const *argv[])
......@@ -71,18 +116,7 @@ int main(int argc, char const *argv[])
{
std::string file = argv[1];
std::fstream input(file.c_str(), std::ios::in | std::ios::binary);
onnx::ModelProto model;
if(model.ParseFromIstream(&input)) {
std::cout << "Model version: " << model.model_version() << std::endl;
std::cout << "Producer name: " << model.producer_name() << std::endl;
std::cout << "Producer version: " << model.release_producer_version() << std::endl;
if(model.has_graph()) {
std::cout << "Model has graph" << std::endl;
parse_graph(model.graph());
}
} else {
std::cout << "Failed reading: " << file << std::endl;
}
auto prog = parse_onnx(input);
prog->print();
}
}
......@@ -8,7 +8,7 @@
namespace rtg {
shape::shape()
: type_(float_type), lens_(), strides_()
: type_(float_type), lens_(), strides_(), packed_(false)
{}
shape::shape(type_t t)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment