#include #include #include #include #include #include #include struct unknown { std::string op; std::string name() const { return "unknown:"+op; } rtg::shape compute_shape(std::vector input) const { if(input.empty()) return {}; else return input.front(); } rtg::argument compute(std::vector input) const { throw "not computable"; } }; struct onnx_parser { std::unordered_map nodes; std::unordered_map instructions; std::shared_ptr prog = std::make_shared(); void parse_graph(const onnx::GraphProto& graph) { nodes = get_nodes(graph); for(auto&& input:graph.input()) { std::string name = input.name(); // TODO: Get shape of input parameter instructions[name] = prog->add_parameter(name, rtg::shape{}); } for(auto&& p:nodes) { this->parse_node(p.second.name()); } } void parse_node(std::string name) { if (instructions.count(name) == 0) { auto&& node = nodes.at(name); std::vector args; for(auto&& input:node.input()) { if(nodes.count(input) > 0) { auto&& iname = nodes.at(input).name(); this->parse_node(iname); args.push_back(instructions.at(iname)); } else { args.push_back(instructions.at(input)); } } instructions[name] = prog->add_instruction(unknown{node.op_type()}, args); } } static std::unordered_map get_attributes(const onnx::NodeProto& node) { std::unordered_map result; for(auto&& attr:node.attribute()) { result[attr.name()] = attr; } return result; } static std::unordered_map get_nodes(const onnx::GraphProto& graph) { std::unordered_map result; for(auto&& node:graph.node()) { result[node.name()] = node; for(auto&& output:node.output()) { result[output] = node; } } return result; } }; std::shared_ptr parse_onnx(std::istream& is) { onnx_parser parser; onnx::ModelProto model; if(model.ParseFromIstream(&is)) { if(model.has_graph()) { parser.parse_graph(model.graph()); } } else { throw "Failed reading"; } return parser.prog; } int main(int argc, char const *argv[]) { if(argc > 1) { std::string file = argv[1]; std::fstream input(file.c_str(), std::ios::in | std::ios::binary); auto prog = parse_onnx(input); prog->print(); } }