// read_onnx.cpp: parse an ONNX model file into an rtg::program and print it.

#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <onnx.pb.h>

#include <rtg/program.hpp>

/// Placeholder operator for an ONNX node the parser does not lower yet.
/// It records only the node's op_type, so the printed program still shows
/// which ONNX operator each instruction came from.
struct unknown
{
    /// ONNX op_type of the node this placeholder stands in for.
    std::string op;

    /// Operator name shown in the program: "unknown:" followed by `op`.
    std::string name() const
    {
        std::string label = "unknown:";
        label += op;
        return label;
    }

    /// Output shape: a copy of the first input's shape, or a
    /// default-constructed shape when the operator has no inputs.
    rtg::shape compute_shape(std::vector<rtg::shape> input) const
    {
        return input.empty() ? rtg::shape{} : input.front();
    }

    /// Always throws: this placeholder only records graph structure and
    /// cannot be evaluated.
    /// NOTE(review): throws a string literal rather than a
    /// std::exception-derived type; kept because callers may catch const char*.
    rtg::argument compute(std::vector<rtg::argument> /*input*/) const
    {
        throw "not computable";
    }
};
struct onnx_parser 
Paul's avatar
Paul committed
30
{
Paul's avatar
Paul committed
31
32
33
34
35
    std::unordered_map<std::string, onnx::NodeProto> nodes;
    std::unordered_map<std::string, rtg::instruction*> instructions;
    std::shared_ptr<rtg::program> prog = std::make_shared<rtg::program>();

    void parse_graph(const onnx::GraphProto& graph)
Paul's avatar
Paul committed
36
    {
Paul's avatar
Paul committed
37
38
39
40
41
42
43
44
45
46
47
        nodes = get_nodes(graph);
        for(auto&& input:graph.input())
        {
            std::string name = input.name();
            // TODO: Get shape of input parameter
            instructions[name] = prog->add_parameter(name, rtg::shape{});
        }
        for(auto&& p:nodes)
        {
            this->parse_node(p.second.name());
        }
Paul's avatar
Paul committed
48
49
    }

Paul's avatar
Paul committed
50
    void parse_node(std::string name)
Paul's avatar
Paul committed
51
    {
Paul's avatar
Paul committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
        if (instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
            std::vector<rtg::instruction*> args;
            for(auto&& input:node.input())
            {
                if(nodes.count(input) > 0)
                {
                    auto&& iname = nodes.at(input).name();
                    this->parse_node(iname);
                    args.push_back(instructions.at(iname));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }
            instructions[name] = prog->add_instruction(unknown{node.op_type()}, args);
        }
Paul's avatar
Paul committed
71
72
    }

Paul's avatar
Paul committed
73
74
75
    static std::unordered_map<std::string, onnx::AttributeProto> get_attributes(const onnx::NodeProto& node)
    {
        std::unordered_map<std::string, onnx::AttributeProto> result;
Paul's avatar
Paul committed
76
77
        for(auto&& attr:node.attribute())
        {
Paul's avatar
Paul committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
            result[attr.name()] = attr;
        }
        return result;
    }

    static std::unordered_map<std::string, onnx::NodeProto> get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        for(auto&& node:graph.node())
        {
            result[node.name()] = node;
            for(auto&& output:node.output())
            {
                result[output] = node;
            }

        }
        return result;
    }
};

/// Read a serialized onnx::ModelProto from `is` and convert its graph into
/// an rtg::program. A model without a graph yields an empty program.
/// Throws the string literal "Failed reading" if protobuf cannot parse `is`.
std::shared_ptr<rtg::program> parse_onnx(std::istream& is)
{
    onnx_parser parser;
    onnx::ModelProto model;

    // Bail out early on a stream protobuf could not decode.
    if(!model.ParseFromIstream(&is))
        throw "Failed reading";

    if(model.has_graph())
        parser.parse_graph(model.graph());

    return parser.prog;
}

int main(int argc, char const *argv[])
{
    if(argc > 1)
    {
        std::string file = argv[1];
        std::fstream input(file.c_str(), std::ios::in | std::ios::binary);
Paul's avatar
Paul committed
119
120
        auto prog = parse_onnx(input);
        prog->print();
Paul's avatar
Paul committed
121
122
    }
}