Unverified Commit 5a87fcbd authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Dynamic dimension input onnx parser (#1249)

Depends on #1199

Adds ONNX parser functionality for dynamic input shapes.
Uses options parameter in parse_onnx()
parent 39b307b2
...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -33,10 +33,20 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser /// struct to pass in onnx options to parser
struct onnx_options struct onnx_options
{ {
/// default batch size to use (if not specified in onnx file) /// Old way to set default fixed dimension size
std::size_t default_dim_value = 1; std::size_t default_dim_value = 0;
/*!
* Default dynamic dimension size (if both default_dim_value and default_dyn_dim_value
* set parser throws)
*/
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input /// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/*!
* Explicitly specify dynamic dims of an input (if both map_input_dims and
* map_dyn_input_dims set parser throws)
*/
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found /// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
/// Print program if an error occurs /// Print program if an error occurs
......
...@@ -94,8 +94,9 @@ struct onnx_parser ...@@ -94,8 +94,9 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
std::size_t default_dim_value = 1; shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t opset_version = 13;
......
...@@ -42,7 +42,24 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -42,7 +42,24 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{ {
onnx::onnx_parser parser; onnx::onnx_parser parser;
parser.map_input_dims = options.map_input_dims; parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value; parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if(dim_val != 0)
{
if(options.default_dyn_dim_value != shape::dynamic_dimension{1, 1, 0})
{
MIGRAPHX_THROW("PARSE_ONNX_FROM: both default_dim_value and default_dyn_dim_value"
"set to non-default value");
}
else
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
}
}
else
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
......
...@@ -35,9 +35,11 @@ ...@@ -35,9 +35,11 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
...@@ -255,6 +257,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -255,6 +257,11 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
{ {
if(not map_input_dims.empty() and not map_dyn_input_dims.empty())
{
MIGRAPHX_THROW("PARSE_GRAPH: both map_input_dims and map_dyn_input_dims non-empty, only"
"one should be used");
}
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
{ {
...@@ -268,7 +275,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -268,7 +275,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// input not in initializer_data, so it is a real input // input not in initializer_data, so it is a real input
if(!contains(mod_insts, name)) if(!contains(mod_insts, name))
{ {
// ONNX specification does not specify hwo to deal with the // ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the // scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph. // name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that. // In the current implementation, MIGraphX throws an exception for that.
...@@ -278,13 +285,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -278,13 +285,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
"\" existing in parent graph!"); "\" existing in parent graph!");
} }
shape s;
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0) if(map_input_dims.count(name) > 0)
{ {
dims = map_input_dims.at(name); dims = map_input_dims.at(name);
s = parse_type(input.type(), dims);
}
else if(map_dyn_input_dims.count(name) > 0)
{
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)};
}
else
{
s = parse_type(input.type(), dims);
} }
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s); mod_insts[name] = mod->add_parameter(name, s);
} }
} }
...@@ -439,30 +455,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -439,30 +455,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return {shape_type, input_dims}; return {shape_type, input_dims};
} }
std::vector<std::size_t> dims; std::vector<shape::dynamic_dimension> dynamic_dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(), std::transform(tensor_dims.begin(),
tensor_dims.end(), tensor_dims.end(),
std::back_inserter(dims), std::back_inserter(dynamic_dims),
[&](auto&& d) -> std::size_t { [&](auto&& d) -> shape::dynamic_dimension {
if(d.has_dim_value()) if(d.has_dim_value())
{ {
if(static_cast<int>(d.dim_value()) <= 0) if(static_cast<int>(d.dim_value()) <= 0)
{ {
return default_dim_value; return default_dyn_dim_value;
} }
return d.dim_value(); std::size_t tmp = d.dim_value();
return {tmp, tmp, 0};
} }
else else
{ {
return default_dim_value; return default_dyn_dim_value;
} }
}); });
if(dims.empty()) if(dynamic_dims.empty())
{
return {shape_type}; return {shape_type};
}
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims}; return {shape_type, dims};
}
return {shape_type, dynamic_dims};
} }
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
......
...@@ -5437,7 +5437,59 @@ TEST_CASE(variable_batch_test) ...@@ -5437,7 +5437,59 @@ TEST_CASE(variable_batch_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(variable_batch_user_input_test) TEST_CASE(variable_batch_user_input_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 2, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 5, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test4)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
...@@ -5453,6 +5505,26 @@ TEST_CASE(variable_batch_user_input_test) ...@@ -5453,6 +5505,26 @@ TEST_CASE(variable_batch_user_input_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(variable_batch_user_input_test5)
{
// Error using default_dim_value and default_dyn_dim_value
migraphx::onnx_options options;
options.default_dim_value = 2;
options.default_dyn_dim_value = {1, 2, 0};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
}
TEST_CASE(variable_batch_user_input_test6)
{
// Error using both map_dyn_input_dims and map_input_dims
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
options.map_input_dims["0"] = {2, 3, 16, 16};
EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
}
TEST_CASE(variable_batch_leq_zero_test) TEST_CASE(variable_batch_leq_zero_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment