Unverified Commit 45bb91ea authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Usability updates to onnx parsing (#480)



* Add skip unknown operators flag

* Formatting

* Add flag to print program on error

* Formatting

* Fix compile error in py

* Formatting

* Workaround cppcheck error

* Initialize with struct

* Formatting

* Disable warning

* Formatting

* Add test for print errors

* Formatting

* Formatting

* Fix compiler error

* Formatting

* Formatting

* Formatting

* Use correct map

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 94addd98
......@@ -101,6 +101,8 @@ else()
-Wno-sign-conversion
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-c99-extensions
# -Wno-c++2a-designator
)
else()
list(APPEND CMAKE_COMPILER_WARNINGS
......
......@@ -192,12 +192,15 @@ program
parse_onnx
----------
.. py:function:: parse_onnx(filename, batch_size=1)
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false)
Load and parse an onnx file.
:param str filename: Path to file.
:param str batch_size: default batch size to use (if not specified in onnx file).
:param str default_dim_value: default batch size to use (if not specified in onnx file).
:param str map_input_dims: Explicitly specify the dims of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs.
:rtype: program
......
......@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp
)
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 1.0)
rocm_set_soversion(migraphx_c 2.0)
rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
......
......@@ -34,6 +34,7 @@ struct loader
bool is_nhwc = true;
unsigned trim = 0;
bool optimize = false;
bool skip_unknown_operators = false;
void parse(argument_parser& ap)
{
......@@ -43,6 +44,10 @@ struct loader
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(batch, {"--batch"}, ap.help("Set batch size for model"));
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(skip_unknown_operators,
{"--skip-unknown-operators"},
ap.help("Skip unknown operators when parsing and continue to parse."),
ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
......@@ -62,10 +67,18 @@ struct loader
}
std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx")
p = parse_onnx(file, onnx_options{batch});
{
onnx_options options;
options.default_dim_value = batch;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = true;
p = parse_onnx(file, options);
}
else if(file_type == "tf")
{
p = parse_tf(file, tf_options{is_nhwc, batch});
}
}
else
{
if(model == "resnet50")
......
......@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
/// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
};
/// Create a program from an onnx file
......
......@@ -38,6 +38,7 @@ struct onnx_parser
bool is_pytorch = false;
std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false;
std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs;
......@@ -1981,7 +1982,10 @@ struct onnx_parser
std::size_t output_num = static_cast<std::size_t>(node.output().size());
if(ops.count(node.op_type()) == 0)
{
if(skip_unknown_operators)
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
else
MIGRAPHX_THROW("Unknown operator: " + node.op_type());
}
else
{
......@@ -2237,8 +2241,10 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
onnx_parser parser;
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
#ifndef NDEBUG
if(options.print_program_on_error)
{
// Log the program when it can't be parsed
try
{
......@@ -2249,9 +2255,11 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
std::cerr << parser.prog << std::endl;
throw;
}
#else
}
else
{
parser.parse_from(std::forward<Ts>(xs)...);
#endif
}
return std::move(parser.prog);
}
......
......@@ -192,17 +192,23 @@ PYBIND11_MODULE(migraphx, m)
m.def("parse_onnx",
[](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::size_t value) {
bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims;
options.default_dim_value = value;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
py::arg("filename"),
py::arg("map_input_dims") = std::map<std::string, std::vector<std::size_t>>(),
py::arg("value") = 1);
py::arg("default_dim_value") = 1,
py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu")
......
......@@ -13,7 +13,9 @@
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
{
auto prog = migraphx::parse_onnx(name);
migraphx::onnx_options options;
options.skip_unknown_operators = true;
auto prog = migraphx::parse_onnx(name, options);
if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
......@@ -1717,6 +1719,18 @@ TEST_CASE(unknown_test)
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); }));
}
TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
options.print_program_on_error = true;
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx", options); }));
}
TEST_CASE(variable_batch_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment