Commit 2265e0d8 authored by Paul

Formatting

parent 7ac2d6a5
@@ -24,10 +24,7 @@ struct unknown
         else
             return input.front();
     }
-    argument compute(shape, std::vector<argument>) const
-    {
-        RTG_THROW("not computable");
-    }
+    argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
     friend std::ostream& operator<<(std::ostream& os, const unknown& x)
     {
         os << x.name();
@@ -51,8 +48,7 @@ struct onnx_parser
 {
     using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
     using node_map = std::unordered_map<std::string, onnx::NodeProto>;
-    using op_func =
-        std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
+    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
     program prog = program();
@@ -165,7 +161,7 @@ struct onnx_parser
         {
             const std::string& name = input.name();
             // TODO: Get shape of input parameter
             shape s = parse_type(input.type());
             instructions[name] = prog.add_parameter(name, s);
         }
         for(auto&& p : nodes)
@@ -245,8 +241,7 @@ struct onnx_parser
         case onnx::AttributeProto::STRING: return {};
         case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
         case onnx::AttributeProto::GRAPH: return {};
-        case onnx::AttributeProto::FLOATS:
-            return from_repeated(shape::float_type, attr.floats());
+        case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
         case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
         case onnx::AttributeProto::STRINGS: return {};
         case onnx::AttributeProto::TENSORS: return {};
@@ -262,28 +257,21 @@ struct onnx_parser
         {
         case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
         case onnx::TensorProto::FLOAT:
-            return literal{
-                {shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
+            return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
         case onnx::TensorProto::UINT8: throw std::runtime_error("");
         case onnx::TensorProto::INT8:
-            return literal{
-                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::UINT16:
-            return literal{
-                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::INT16:
-            return literal{
-                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::INT32:
-            return literal{
-                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::INT64:
-            return literal{
-                {shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
+            return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
         case onnx::TensorProto::STRING: throw std::runtime_error("");
         case onnx::TensorProto::BOOL:
-            return literal{
-                {shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
+            return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
         case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
         case onnx::TensorProto::DOUBLE:
             return literal{
...
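
A note on the `op_func` alias reformatted in the `@@ -51,8 +48,7 @@` hunk above: the parser dispatches each ONNX node by op name to a `std::function` that takes the node's parsed attributes and its already-built inputs and returns an `instruction_ref`. The sketch below shows that dispatch pattern in isolation, with stand-in types in place of the real `attribute_map` and `instruction_ref` so it compiles on its own; it is an illustration, not the rtg implementation.

```cpp
// Standalone sketch of the op_func dispatch pattern (stand-in types, not rtg code).
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using attribute_map   = std::unordered_map<std::string, std::string>; // stand-in for the ONNX attribute map
using instruction_ref = int;                                          // stand-in for a program instruction handle
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;

int main()
{
    // Op name -> handler table, analogous to how the parser registers its ONNX operators.
    std::unordered_map<std::string, op_func> ops;
    ops["Identity"] = [](attribute_map, std::vector<instruction_ref> args) {
        return args.front(); // an Identity-style op just forwards its first input
    };
    std::cout << ops["Identity"]({}, {42}) << "\n"; // prints 42
}
```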
@@ -25,7 +25,7 @@ int main(int argc, char const* argv[])
     if(argc > 1)
     {
         std::string file = argv[1];
         auto prog = rtg::parse_onnx(file);
         prog.compile(rtg::cpu::cpu_target{});
         auto s = prog.get_parameter_shape("Input3");
         auto input3 = get_tensor_argument(s);
...
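
For the driver hunk above, the flow is: parse the ONNX file into a program, compile it for the CPU target, query a parameter's shape, and build a matching input. The sketch below restates that flow with comments and is an assumption-laden illustration only: `get_tensor_argument` is the helper defined earlier in that same file (not shown in this diff), the name "Input3" is specific to the model being tested, the final `prog.eval(...)` call is an assumed API that this diff does not show, and the rtg includes are omitted because their paths are not visible here.

```cpp
// Sketch of the end-to-end flow shown in the hunk above
// (rtg includes omitted; they are not visible in this diff).
#include <string>

int main(int argc, char const* argv[])
{
    if(argc > 1)
    {
        std::string file = argv[1];
        auto prog = rtg::parse_onnx(file);    // build a program from the ONNX graph
        prog.compile(rtg::cpu::cpu_target{}); // lower it for the CPU target
        auto s = prog.get_parameter_shape("Input3"); // "Input3" is model-specific
        auto input3 = get_tensor_argument(s); // helper defined earlier in this file
        auto result = prog.eval({{"Input3", input3}}); // assumed eval API, not part of this diff
        (void)result;
    }
}
```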