Commit bc2146b0 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added code to read initializer data from ONNX

parent b9890d91
...@@ -155,12 +155,12 @@ struct pooling ...@@ -155,12 +155,12 @@ struct pooling
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) / std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0]))) + static_cast<float>(stride[0]))) +
1)), 1)),
std::size_t(std::max<std::ptrdiff_t>( std::size_t(std::max<std::ptrdiff_t>(
1, 1,
std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) / std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) + static_cast<float>(stride[1]))) +
1)), 1)),
// std::size_t(std::max<std::ptrdiff_t>( // std::size_t(std::max<std::ptrdiff_t>(
// 1, // 1,
......
...@@ -80,11 +80,13 @@ struct shape ...@@ -80,11 +80,13 @@ struct shape
/// Returns true if the shape is packed with no padding /// Returns true if the shape is packed with no padding
bool packed() const; bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending order /// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const; bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero /// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const; bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and not transposed. /// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const; bool standard() const;
friend bool operator==(const shape& x, const shape& y); friend bool operator==(const shape& x, const shape& y);
......
...@@ -20,6 +20,10 @@ add_executable(mnist mnist.cpp) ...@@ -20,6 +20,10 @@ add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_onnx) target_link_libraries(mnist migraph_cpu migraph_onnx)
add_executable(resnet18 resnet18.cpp)
rocm_clang_tidy_check(resnet18)
target_link_libraries(resnet18 migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp) add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx) rocm_clang_tidy_check(verify_onnx)
......
...@@ -285,12 +285,28 @@ struct onnx_parser ...@@ -285,12 +285,28 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph) void parse_graph(const onnx::GraphProto& graph)
{ {
nodes = get_nodes(graph); nodes = get_nodes(graph);
std::unordered_map<std::string, size_t> initializer_data;
auto cnt = 0;
for(auto&& f : graph.initializer())
{
initializer_data[f.name()] = cnt++;
}
for(auto&& input : graph.input()) for(auto&& input : graph.input())
{ {
const std::string& name = input.name(); const std::string& name = input.name();
// TODO: Get shape of input parameter // Does the input have an initializer?
shape s = parse_type(input.type()); if(initializer_data.find(name) != initializer_data.end())
instructions[name] = prog.add_parameter(name, s); {
auto idx = initializer_data[name];
auto t = graph.initializer()[idx];
instructions[name] = prog.add_literal(parse_tensor(t));
}
else
{
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s);
}
} }
for(auto&& p : nodes) for(auto&& p : nodes)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment