Commit c6978f5d authored by Paul

Add extra checks in onnx parser

parent 2265e0d8
@@ -24,7 +24,10 @@ struct unknown
         else
             return input.front();
     }
-    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
+    argument compute(shape, std::vector<argument>) const
+    {
+        RTG_THROW("not computable");
+    }
     friend std::ostream& operator<<(std::ostream& os, const unknown& x)
     {
         os << x.name();
@@ -48,7 +51,8 @@ struct onnx_parser
 {
     using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
     using node_map = std::unordered_map<std::string, onnx::NodeProto>;
-    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
+    using op_func =
+        std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
@@ -161,7 +165,7 @@ struct onnx_parser
         {
             const std::string& name = input.name();
             // TODO: Get shape of input parameter
-            shape s = parse_type(input.type());
+            shape s            = parse_type(input.type());
             instructions[name] = prog.add_parameter(name, s);
         }
         for(auto&& p : nodes)
@@ -172,6 +176,7 @@ struct onnx_parser
     void parse_node(std::string name)
     {
+        if(name.empty()) RTG_THROW("Onnx node must have a name");
         if(instructions.count(name) == 0)
         {
             auto&& node = nodes.at(name);
@@ -181,6 +186,7 @@ struct onnx_parser
                 if(nodes.count(input) > 0)
                 {
                     auto&& iname = nodes.at(input).name();
+                    assert(name != iname);
                     this->parse_node(iname);
                     args.push_back(instructions.at(iname));
                 }
@@ -241,7 +247,8 @@ struct onnx_parser
         case onnx::AttributeProto::STRING: return {};
         case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
         case onnx::AttributeProto::GRAPH: return {};
-        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
+        case onnx::AttributeProto::FLOATS:
+            return from_repeated(shape::float_type, attr.floats());
         case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
         case onnx::AttributeProto::STRINGS: return {};
         case onnx::AttributeProto::TENSORS: return {};
@@ -257,21 +264,28 @@ struct onnx_parser
         {
         case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
         case onnx::TensorProto::FLOAT:
-            return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+            return literal{
+                {shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
         case onnx::TensorProto::UINT8: throw std::runtime_error("");
         case onnx::TensorProto::INT8:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{
+                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::UINT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{
+                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT16:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{
+                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT32:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{
+                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::INT64:
-            return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+            return literal{
+                {shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
        case onnx::TensorProto::STRING: throw std::runtime_error("");
        case onnx::TensorProto::BOOL:
-            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{
+                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
        case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
        case onnx::TensorProto::DOUBLE:
            return literal{
...
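The two added checks both guard the recursive walk in parse_node: an empty node name cannot be used as a key into the nodes/instructions maps, and a node that lists itself as an input would recurse without bound. Below is a minimal, self-contained sketch of that pattern, assuming simplified stand-in types: the node struct, the int placeholder result, and main are illustrative only, and a plain throw/assert stand in for RTG_THROW and for resolving the input to a node name via nodes.at(input).name().

// Minimal sketch only: "node", "nodes", "instructions" here are simplified
// stand-ins for the onnx::NodeProto/instruction_ref types used by the parser.
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct node // hypothetical stand-in for onnx::NodeProto
{
    std::vector<std::string> inputs;
};

std::unordered_map<std::string, node> nodes;       // node name -> node
std::unordered_map<std::string, int> instructions; // node name -> parsed result (placeholder)

void parse_node(const std::string& name)
{
    // Check 1: an unnamed node cannot be looked up by name, so fail early
    // (the real parser throws via RTG_THROW).
    if(name.empty())
        throw std::runtime_error("Onnx node must have a name");
    if(instructions.count(name) == 0)
    {
        for(auto&& input : nodes.at(name).inputs)
        {
            if(nodes.count(input) > 0)
            {
                // Check 2: a node that directly feeds itself would recurse
                // forever; the assert catches that case in debug builds.
                assert(name != input);
                parse_node(input);
            }
        }
        instructions[name] = 1; // placeholder for adding the real instruction
    }
}

int main()
{
    nodes["a"] = node{{}};
    nodes["b"] = node{{"a"}};
    parse_node("b"); // resolves "a" first, then "b"
    std::cout << "parsed " << instructions.size() << " nodes\n";
}

As in the real change, the assert only catches a direct self-reference; a longer cycle among nodes would still recurse, so it is a debug-time sanity check rather than full cycle detection.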