"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "7746373c1999b1a02da70addd0bf43e4bddb1baa"
Commit bc2146b0 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added code to read initializer data from ONNX

parent b9890d91
......@@ -155,12 +155,12 @@ struct pooling
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[2] + 2 * padding[0] - lengths[0]) /
static_cast<float>(stride[0]))) +
static_cast<float>(stride[0]))) +
1)),
std::size_t(std::max<std::ptrdiff_t>(
1,
std::ptrdiff_t(std::floor((input.lens()[3] + 2 * padding[1] - lengths[1]) /
static_cast<float>(stride[1]))) +
static_cast<float>(stride[1]))) +
1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
......
......@@ -80,11 +80,13 @@ struct shape
/// Returns true if the shape is packed with no padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending order
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// order
bool transposed() const;
/// Returns true if the shape is broadcasting a dimension. That is, one of the strides are zero
bool broadcasted() const;
/// Returns true if the shape is in its standard format. That is, the shape is both packed and not transposed.
/// Returns true if the shape is in its standard format. That is, the shape is both packed and
/// not transposed.
bool standard() const;
friend bool operator==(const shape& x, const shape& y);
......
......@@ -20,6 +20,10 @@ add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_onnx)
add_executable(resnet18 resnet18.cpp)
rocm_clang_tidy_check(resnet18)
target_link_libraries(resnet18 migraph_cpu migraph_onnx)
if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
......
......@@ -285,12 +285,28 @@ struct onnx_parser
void parse_graph(const onnx::GraphProto& graph)
{
nodes = get_nodes(graph);
std::unordered_map<std::string, size_t> initializer_data;
auto cnt = 0;
for(auto&& f : graph.initializer())
{
initializer_data[f.name()] = cnt++;
}
for(auto&& input : graph.input())
{
const std::string& name = input.name();
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s);
// Does the input have an initializer?
if(initializer_data.find(name) != initializer_data.end())
{
auto idx = initializer_data[name];
auto t = graph.initializer()[idx];
instructions[name] = prog.add_literal(parse_tensor(t));
}
else
{
// TODO: Get shape of input parameter
shape s = parse_type(input.type());
instructions[name] = prog.add_parameter(name, s);
}
}
for(auto&& p : nodes)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment