Commit f0a9d415 authored by Paul
Browse files

Add onnx reader to analyzer

parent 6dd3cc0e
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace rtg {
// RTG_FALLTHROUGH marks an intentional switch-case fallthrough so that
// compiler fallthrough diagnostics (e.g. -Wimplicit-fallthrough) are
// suppressed for that case. It expands to nothing on compilers without a
// supported annotation, so it is always safe to place before a case label.
//
// NOTE: macros ignore namespaces; the enclosing namespace is kept only to
// match the project's header layout.
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#elif defined(__GNUC__) && __GNUC__ >= 7
// GCC 7+ warns on implicit fallthrough too; use its attribute spelling.
#define RTG_FALLTHROUGH __attribute__((fallthrough))
#else
#define RTG_FALLTHROUGH
#endif
} // namespace rtg
#endif
......@@ -14,9 +14,9 @@ struct not_computable
struct convolution
{
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> dilation = {1, 1};
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> dilation = {{1, 1}};
std::string name() const
{
return "convolution[padding={" + to_string(padding) + "}, stride={" + to_string(stride) +
......@@ -61,9 +61,9 @@ struct convolution
struct pooling
{
std::string mode;
std::array<std::size_t, 2> padding = {0, 0};
std::array<std::size_t, 2> stride = {1, 1};
std::array<std::size_t, 2> lengths = {1, 1};
std::array<std::size_t, 2> padding = {{0, 0}};
std::array<std::size_t, 2> stride = {{1, 1}};
std::array<std::size_t, 2> lengths = {{1, 1}};
std::string name() const
{
return "pooling:" + mode + "[padding={" + to_string(padding) + "}, stride={" +
......
# Generate C++ sources from the ONNX protobuf schema and wrap them in a
# static library so warnings from generated code are suppressed in one place.
find_package(Protobuf REQUIRED)
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS onnx.proto)
# Generated headers (onnx.pb.h) land in the build tree.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_library(onnx-proto STATIC ${PROTO_SRCS})
# SYSTEM keeps protobuf/generated headers out of our warning set;
# PUBLIC propagates the include paths to consumers of onnx-proto.
target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
# Silence all warnings inside the generated protobuf code itself.
target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
# The reader executable: a single target, linked against the proto
# wrapper library (which carries protobuf transitively) and rtg.
add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx onnx-proto rtg)
......@@ -7,6 +7,7 @@
#include <unordered_map>
#include <functional>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
......@@ -21,7 +22,7 @@ struct unknown
else
return input.front();
}
rtg::argument compute(std::vector<rtg::argument> input) const { throw "not computable"; }
rtg::argument compute(std::vector<rtg::argument>) const { RTG_THROW("not computable"); }
};
template <class C, class T>
......@@ -84,7 +85,7 @@ struct onnx_parser
}
return prog->add_instruction(op, args);
});
add_op("Relu", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
add_op("Relu", [this](attribute_map, std::vector<rtg::instruction*> args) {
return prog->add_instruction(rtg::activation{"relu"}, args);
});
add_op("Reshape", [this](attribute_map attributes, std::vector<rtg::instruction*> args) {
......@@ -126,7 +127,7 @@ struct onnx_parser
nodes = get_nodes(graph);
for(auto&& input : graph.input())
{
std::string name = input.name();
const std::string& name = input.name();
// TODO: Get shape of input parameter
rtg::shape s = parse_type(input.type());
instructions[name] = prog->add_parameter(name, s);
......@@ -254,28 +255,28 @@ struct onnx_parser
static rtg::shape parse_type(const onnx::TypeProto& t)
{
rtg::shape::type_t shape_type;
rtg::shape::type_t shape_type{};
switch(t.tensor_type().elem_type())
{
case onnx::TensorProto::UNDEFINED:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type;
case onnx::TensorProto::FLOAT: shape_type = rtg::shape::float_type; break;
case onnx::TensorProto::UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type;
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type;
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type;
case onnx::TensorProto::INT8: shape_type = rtg::shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = rtg::shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = rtg::shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = rtg::shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = rtg::shape::int64_type; break;
case onnx::TensorProto::STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16:
break; // throw std::runtime_error("Unsupported type FLOAT16");
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type;
case onnx::TensorProto::DOUBLE: shape_type = rtg::shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = rtg::shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = rtg::shape::uint64_type; break;
case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case onnx::TensorProto::COMPLEX128:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment