Commit f0a9d415 authored by Paul's avatar Paul
Browse files

Add onnx reader to analyzer

parent 6dd3cc0e
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace rtg {
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
} // namespace rtg
#endif
...@@ -14,9 +14,9 @@ struct not_computable ...@@ -14,9 +14,9 @@ struct not_computable
struct convolution struct convolution
{ {
std::array<std::size_t, 2> padding = {0, 0}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {1, 1}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {1, 1}; std::array<std::size_t, 2> dilation = {{1, 1}};
std::string name() const std::string name() const
{ {
return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) + return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) +
...@@ -61,9 +61,9 @@ struct convolution ...@@ -61,9 +61,9 @@ struct convolution
struct pooling struct pooling
{ {
std::string mode; std::string mode;
std::array<std::size_t, 2> padding = {0, 0}; std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {1, 1}; std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {1, 1}; std::array<std::size_t, 2> lengths = {{1, 1}};
std::string name() const std::string name() const
{ {
return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" + return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" +
......
find_package(Protobuf REQUIRED) find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto) protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
include_directories(${CMAKE_CURRENT_BINARY_DIR}) add_library(onnx-proto STATIC ${PROTO_SRCS})
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
add_executable(read_onnx read_onnx.cpp ${PROTO_SRCS}) add_executable(read_onnx read_onnx.cpp)
target_include_directories(read_onnx PUBLIC ${PROTOBUF_INCLUDE_DIR}) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx ${PROTOBUF_LIBRARY} rtg) target_link_libraries(read_onnx onnx-proto rtg)
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <unordered_map> #include <unordered_map>
#include <functional> #include <functional>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp> #include <rtg/program.hpp>
#include <rtg/operators.hpp> #include <rtg/operators.hpp>
...@@ -21,7 +22,7 @@ struct unknown ...@@ -21,7 +22,7 @@ struct unknown
else else
return input.front(); return input.front();
} }
rtg::argument compute(std::vector<rtg::argument> input) const { throw "not computable"; } rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
}; };
template <class C, class T> template <class C, class T>
...@@ -84,7 +85,7 @@ struct onnx_parser ...@@ -84,7 +85,7 @@ struct onnx_parser
} }
return prog->add_instruction(op, args); return prog->add_instruction(op, args);
}); });
add_op("Relu", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) {
return prog->add_instruction(rtg::activation{"relu"}, args); return prog->add_instruction(rtg::activation{"relu"}, args);
}); });
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) { add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
...@@ -126,7 +127,7 @@ struct onnx_parser ...@@ -126,7 +127,7 @@ struct onnx_parser
nodes = get_nodes(graph); nodes = get_nodes(graph);
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
std::string name = input.name(); const std::string& name = input.name();
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type()); rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s); instructions[name] = prog->add_parameter(name, s);
...@@ -254,28 +255,28 @@ struct onnx_parser ...@@ -254,28 +255,28 @@ struct onnx_parser
static rtg::shape parse_type(const onnx::TypeProto& t) static rtg::shape parse_type(const onnx::TypeProto& t)
{ {
rtg::shape::type_t shape_type; rtg::shape::type_t shape_type{};
switch(t.tensor_type().elem_type()) switch(t.tensor_type().elem_type())
{ {
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
break; // throw std::runtime_error("Unsupported type UNDEFINED"); break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type; case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type; break;
case onnx::TensorProto::UINT8: case onnx::TensorProto::UINT8:
break; // throw std::runtime_error("Unsupported type UINT8"); break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type; case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type; case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type; case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type; case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type; case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type; break;
case onnx::TensorProto::STRING: case onnx::TensorProto::STRING:
break; // throw std::runtime_error("Unsupported type STRING"); break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL"); break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
break; // throw std::runtime_error("Unsupported type FLOAT16"); break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type; case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type; case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type; case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type; break;
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64"); break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128: case onnx::TensorProto::COMPLEX128:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment