Commit bc2146b0 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added code to read initializer data from ONNX

parent b9890d91
...@@ -80,11 +80,13 @@ struct shape ...@@ -80,11 +80,13 @@ struct shape
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed with no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending order /// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const; bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const; bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and not transposed. /// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const; bool standard() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
......
...@@ -20,6 +20,10 @@ add_executable(mnist mnist.cpp) ...@@ -20,6 +20,10 @@ add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_onnx) target_link_libraries(mnist migraph_cpu migraph_onnx)
add_executable(resnet18 resnet18.cpp)
rocm_clang_tidy_check(resnet18)
target_link_libraries(resnet18 migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp) add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx) rocm_clang_tidy_check(verify_onnx)
......
...@@ -285,13 +285,29 @@ struct onnx_parser ...@@ -285,13 +285,29 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph) void parse_graph(const onnx::GraphProto& graph)
{ {
nodes = get_nodes(graph); nodes = get_nodes(graph);
std::unordered_map<std::string, size_t> initializer_data;
auto cnt = 0;
for(auto&& f : graph.initializer())
{
initializer_data[f.name()] = cnt++;
}
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
const std::string& name = input.name(); const std::string& name = input.name();
// Does the input have an initializer?
if(initializer_data.find(name) != initializer_data.end())
{
auto idx = initializer_data[name];
auto t = graph.initializer()[idx];
instructions[name] = prog.add_literal(parse_tensor(t));
}
else
{
// TODO: Get shape of input parameter // TODO: Get shape of input parameter
shape s = parse_type(input.type()); shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s); instructions[name] = prog.add_parameter(name, s);
} }
}
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
this->parse_node(get_name(p.second)); this->parse_node(get_name(p.second));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment