Unverified Commit 45bb91ea authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Usability updates to onnx parsing (#480)



* Add skip unknown operators flag

* Formatting

* Add flag to print program on error

* Formatting

* Fix compile error in py

* Formatting

* Workaround cppcheck error

* Initialize with struct

* Formatting

* Disable warning

* Formatting

* Add test for print errors

* Formatting

* Formatting

* Fix compiler error

* Formatting

* Formatting

* Formatting

* Use correct map

* Formatting
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 94addd98
...@@ -101,6 +101,8 @@ else() ...@@ -101,6 +101,8 @@ else()
-Wno-sign-conversion -Wno-sign-conversion
-Wno-unused-command-line-argument -Wno-unused-command-line-argument
-Wno-weak-vtables -Wno-weak-vtables
-Wno-c99-extensions
# -Wno-c++2a-designator
) )
else() else()
list(APPEND CMAKE_COMPILER_WARNINGS list(APPEND CMAKE_COMPILER_WARNINGS
......
...@@ -192,12 +192,15 @@ program ...@@ -192,12 +192,15 @@ program
parse_onnx parse_onnx
---------- ----------
.. py:function:: parse_onnx(filename, batch_size=1) .. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false)
Load and parse an onnx file. Load and parse an onnx file.
:param str filename: Path to file. :param str filename: Path to file.
:param str batch_size: default batch size to use (if not specified in onnx file). :param str default_dim_value: default batch size to use (if not specified in onnx file).
:param str map_input_dims: Explicitly specify the dims of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs.
:rtype: program :rtype: program
......
...@@ -3,7 +3,7 @@ add_library(migraphx_c ...@@ -3,7 +3,7 @@ add_library(migraphx_c
api.cpp api.cpp
) )
set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c) set_target_properties(migraphx_c PROPERTIES EXPORT_NAME c)
rocm_set_soversion(migraphx_c 1.0) rocm_set_soversion(migraphx_c 2.0)
rocm_clang_tidy_check(migraphx_c) rocm_clang_tidy_check(migraphx_c)
target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu) target_link_libraries(migraphx_c PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_cpu)
......
...@@ -34,6 +34,7 @@ struct loader ...@@ -34,6 +34,7 @@ struct loader
bool is_nhwc = true; bool is_nhwc = true;
unsigned trim = 0; unsigned trim = 0;
bool optimize = false; bool optimize = false;
bool skip_unknown_operators = false;
void parse(argument_parser& ap) void parse(argument_parser& ap)
{ {
...@@ -43,6 +44,10 @@ struct loader ...@@ -43,6 +44,10 @@ struct loader
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf")); ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
ap(batch, {"--batch"}, ap.help("Set batch size for model")); ap(batch, {"--batch"}, ap.help("Set batch size for model"));
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true)); ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(skip_unknown_operators,
{"--skip-unknown-operators"},
ap.help("Skip unknown operators when parsing and continue to parse."),
ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false)); ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end")); ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true)); ap(optimize, {"--optimize", "-O"}, ap.help("Optimize when reading"), ap.set_value(true));
...@@ -62,10 +67,18 @@ struct loader ...@@ -62,10 +67,18 @@ struct loader
} }
std::cout << "Reading: " << file << std::endl; std::cout << "Reading: " << file << std::endl;
if(file_type == "onnx") if(file_type == "onnx")
p = parse_onnx(file, onnx_options{batch}); {
onnx_options options;
options.default_dim_value = batch;
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = true;
p = parse_onnx(file, options);
}
else if(file_type == "tf") else if(file_type == "tf")
{
p = parse_tf(file, tf_options{is_nhwc, batch}); p = parse_tf(file, tf_options{is_nhwc, batch});
} }
}
else else
{ {
if(model == "resnet50") if(model == "resnet50")
......
...@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,8 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser /// struct to pass in onnx options to parser
struct onnx_options struct onnx_options
{ {
/// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1; std::size_t default_dim_value = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
}; };
/// Create a program from an onnx file /// Create a program from an onnx file
......
...@@ -38,6 +38,7 @@ struct onnx_parser ...@@ -38,6 +38,7 @@ struct onnx_parser
bool is_pytorch = false; bool is_pytorch = false;
std::size_t default_dim_value = 1; std::size_t default_dim_value = 1;
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
bool skip_unknown_operators = false;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
std::unordered_map<std::string, operation> map_actv_funcs; std::unordered_map<std::string, operation> map_actv_funcs;
...@@ -1981,7 +1982,10 @@ struct onnx_parser ...@@ -1981,7 +1982,10 @@ struct onnx_parser
std::size_t output_num = static_cast<std::size_t>(node.output().size()); std::size_t output_num = static_cast<std::size_t>(node.output().size());
if(ops.count(node.op_type()) == 0) if(ops.count(node.op_type()) == 0)
{ {
if(skip_unknown_operators)
result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args)); result.push_back(prog.add_instruction(op::unknown{node.op_type()}, args));
else
MIGRAPHX_THROW("Unknown operator: " + node.op_type());
} }
else else
{ {
...@@ -2237,8 +2241,10 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -2237,8 +2241,10 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
onnx_parser parser; onnx_parser parser;
parser.map_input_dims = options.map_input_dims; parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value; parser.default_dim_value = options.default_dim_value;
parser.skip_unknown_operators = options.skip_unknown_operators;
#ifndef NDEBUG if(options.print_program_on_error)
{
// Log the program when it can't be parsed // Log the program when it can't be parsed
try try
{ {
...@@ -2249,9 +2255,11 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -2249,9 +2255,11 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
std::cerr << parser.prog << std::endl; std::cerr << parser.prog << std::endl;
throw; throw;
} }
#else }
else
{
parser.parse_from(std::forward<Ts>(xs)...); parser.parse_from(std::forward<Ts>(xs)...);
#endif }
return std::move(parser.prog); return std::move(parser.prog);
} }
......
...@@ -192,17 +192,23 @@ PYBIND11_MODULE(migraphx, m) ...@@ -192,17 +192,23 @@ PYBIND11_MODULE(migraphx, m)
m.def("parse_onnx", m.def("parse_onnx",
[](const std::string& filename, [](const std::string& filename,
unsigned int default_dim_value,
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims, std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
std::size_t value) { bool skip_unknown_operators,
bool print_program_on_error) {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.map_input_dims = map_input_dims; options.map_input_dims = map_input_dims;
options.default_dim_value = value; options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
return migraphx::parse_onnx(filename, options); return migraphx::parse_onnx(filename, options);
}, },
"Parse onnx file", "Parse onnx file",
py::arg("filename"), py::arg("filename"),
py::arg("map_input_dims") = std::map<std::string, std::vector<std::size_t>>(), py::arg("default_dim_value") = 1,
py::arg("value") = 1); py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false);
m.def("get_target", [](const std::string& name) -> migraphx::target { m.def("get_target", [](const std::string& name) -> migraphx::target {
if(name == "cpu") if(name == "cpu")
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false) migraphx::program optimize_onnx(const std::string& name, bool eliminate_deadcode = false)
{ {
auto prog = migraphx::parse_onnx(name); migraphx::onnx_options options;
options.skip_unknown_operators = true;
auto prog = migraphx::parse_onnx(name, options);
if(eliminate_deadcode) if(eliminate_deadcode)
migraphx::run_passes(prog, {migraphx::dead_code_elimination{}}); migraphx::run_passes(prog, {migraphx::dead_code_elimination{}});
...@@ -1717,6 +1719,18 @@ TEST_CASE(unknown_test) ...@@ -1717,6 +1719,18 @@ TEST_CASE(unknown_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(unknown_test_throw)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx"); }));
}
TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
options.print_program_on_error = true;
EXPECT(test::throws([&] { migraphx::parse_onnx("unknown_test.onnx", options); }));
}
TEST_CASE(variable_batch_test) TEST_CASE(variable_batch_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment