Commit 1fca08da authored by Paul
Browse files

Formatting

parent c6978f5d
......@@ -24,10 +24,7 @@ struct unknown
else
return input.front();
}
argument compute(shape, std::vector<argument>) const
{
RTG_THROW("not computable");
}
// Evaluation stub: an `unknown` op has no computable semantics, so any attempt
// to execute it aborts via RTG_THROW (parameters are intentionally unnamed/unused).
argument compute(shape, std::vector<argument>) const { RTG_THROW("not computable"); }
friend std::ostream& operator<<(std::ostream& os, const unknown& x)
{
os << x.name();
......@@ -51,8 +48,7 @@ struct onnx_parser
{
using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
using node_map = std::unordered_map<std::string, onnx::NodeProto>;
using op_func =
std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
......@@ -165,7 +161,7 @@ struct onnx_parser
{
const std::string& name = input.name();
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s);
}
for(auto&& p : nodes)
......@@ -176,7 +172,8 @@ struct onnx_parser
void parse_node(std::string name)
{
if(name.empty()) RTG_THROW("Onnx node must have a name");
if(name.empty())
RTG_THROW("Onnx node must have a name");
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
......@@ -247,8 +244,7 @@ struct onnx_parser
case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS:
return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
case onnx::AttributeProto::STRINGS: return {};
case onnx::AttributeProto::TENSORS: return {};
......@@ -264,28 +260,21 @@ struct onnx_parser
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return literal{
{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
return literal{{shape::float_type, dims}, t.float_data().begin(), t.float_data().end()};
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
return literal{
{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::UINT16:
return literal{
{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT16:
return literal{
{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT32:
return literal{
{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::INT64:
return literal{
{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
return literal{{shape::int64_type, dims}, t.int64_data().begin(), t.int64_data().end()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return literal{
{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE:
return literal{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment