Commit abf62c5d authored by charlie's avatar charlie
Browse files

Handle dynamic graph input shapes in ONNX parser

parent faefeef9
......@@ -10,10 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
/// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1;
/// Explicitly specify the dims of an input
/// Old way to set default fixed dimension size (priority over default_dyn_dim_value)
std::size_t default_dim_value = 0;
/// Default dynamic dimension size (if not specified in onnx file)
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input (priority over map_dyn_input_dims)
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Explicitly specify dynamic dims of an input
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
/// Print program if an error occurs
......
......@@ -71,8 +71,9 @@ struct onnx_parser
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
std::size_t default_dim_value = 1;
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
......
......@@ -19,7 +19,16 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{
onnx::onnx_parser parser;
parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value;
parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if (dim_val != 0)
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
}
else
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
......
......@@ -12,9 +12,11 @@
#include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
......@@ -245,7 +247,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// input not in initializer_data, so it is a real input
if(!contains(mod_insts, name))
{
// ONNX specification does not specify hwo to deal with the
// ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that.
......@@ -255,13 +257,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
"\" existing in parent graph!");
}
shape s;
std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0)
{
dims = map_input_dims.at(name);
s = parse_type(input.type(), dims);
}
else if(map_dyn_input_dims.count(name) > 0)
{
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)};
}
else
{
s = parse_type(input.type(), dims);
}
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s);
}
}
......@@ -416,30 +427,45 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return {shape_type, input_dims};
}
std::vector<std::size_t> dims;
std::vector<shape::dynamic_dimension> dynamic_dims;
auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(),
tensor_dims.end(),
std::back_inserter(dims),
[&](auto&& d) -> std::size_t {
std::back_inserter(dynamic_dims),
[&](auto&& d) -> shape::dynamic_dimension {
if(d.has_dim_value())
{
if(static_cast<int>(d.dim_value()) <= 0)
{
return default_dim_value;
return default_dyn_dim_value;
}
return d.dim_value();
auto tmp = static_cast<std::size_t>(d.dim_value());
return {tmp, tmp, 0};
}
else
{
return default_dim_value;
return default_dyn_dim_value;
}
});
if(dims.empty())
if(dynamic_dims.empty())
{
return {shape_type};
}
if (std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(
dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) {
return d.max;
}
);
return {shape_type, dims};
}
return {shape_type, dynamic_dims};
}
shape::type_t get_type(int dtype)
......
......@@ -5411,7 +5411,55 @@ TEST_CASE(variable_batch_test)
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test)
TEST_CASE(variable_batch_user_input_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 2, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 5, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test4)
{
migraphx::program p;
auto* mm = p.get_main_module();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment