Commit 4b1b8032 authored by charlie's avatar charlie
Browse files

Merge branch 'dyn_dim_onnx_parser' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_conv

parents f05785b2 7f1386b3
...@@ -10,10 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,10 +10,14 @@ inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser /// struct to pass in onnx options to parser
struct onnx_options struct onnx_options
{ {
/// default batch size to use (if not specified in onnx file) /// Old way to set default fixed dimension size (priority over default_dyn_dim_value)
std::size_t default_dim_value = 1; std::size_t default_dim_value = 0;
/// Explicitly specify the dims of an input /// Default dynamic dimension size (if not specified in onnx file)
shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
/// Explicitly specify the dims of an input (priority over map_dyn_input_dims)
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {}; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Explicitly specify dynamic dims of an input
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found /// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
/// Print program if an error occurs /// Print program if an error occurs
......
...@@ -71,8 +71,9 @@ struct onnx_parser ...@@ -71,8 +71,9 @@ struct onnx_parser
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
std::size_t default_dim_value = 1; shape::dynamic_dimension default_dyn_dim_value = {1, 1, 0};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool skip_unknown_operators = false; bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10; int64_t max_loop_iterations = 10;
int64_t opset_version = 13; int64_t opset_version = 13;
......
...@@ -19,7 +19,16 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs) ...@@ -19,7 +19,16 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
{ {
onnx::onnx_parser parser; onnx::onnx_parser parser;
parser.map_input_dims = options.map_input_dims; parser.map_input_dims = options.map_input_dims;
parser.default_dim_value = options.default_dim_value; parser.map_dyn_input_dims = options.map_dyn_input_dims;
auto dim_val = options.default_dim_value;
if(dim_val != 0)
{
parser.default_dyn_dim_value = {dim_val, dim_val, 0};
}
else
{
parser.default_dyn_dim_value = options.default_dyn_dim_value;
}
parser.skip_unknown_operators = options.skip_unknown_operators; parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations; parser.max_loop_iterations = options.max_loop_iterations;
......
...@@ -12,9 +12,11 @@ ...@@ -12,9 +12,11 @@
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/filesystem.hpp> #include <migraphx/filesystem.hpp>
#include <migraphx/op/unknown.hpp> #include <migraphx/op/unknown.hpp>
#include <migraphx/env.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace onnx { namespace onnx {
static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
...@@ -245,7 +247,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -245,7 +247,7 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
// input not in initializer_data, so it is a real input // input not in initializer_data, so it is a real input
if(!contains(mod_insts, name)) if(!contains(mod_insts, name))
{ {
// ONNX specification does not specify hwo to deal with the // ONNX specification does not specify how to deal with the
// scenario that a nested subgraph contains a parameter with the // scenario that a nested subgraph contains a parameter with the
// name existed in its parent graph. // name existed in its parent graph.
// In the current implementation, MIGraphX throws an exception for that. // In the current implementation, MIGraphX throws an exception for that.
...@@ -255,13 +257,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -255,13 +257,22 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
"\" existing in parent graph!"); "\" existing in parent graph!");
} }
shape s;
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
if(map_input_dims.count(name) > 0) if(map_input_dims.count(name) > 0)
{ {
dims = map_input_dims.at(name); dims = map_input_dims.at(name);
s = parse_type(input.type(), dims);
}
else if(map_dyn_input_dims.count(name) > 0)
{
shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
s = {shape_type, map_dyn_input_dims.at(name)};
}
else
{
s = parse_type(input.type(), dims);
} }
shape s = parse_type(input.type(), dims);
mod_insts[name] = mod->add_parameter(name, s); mod_insts[name] = mod->add_parameter(name, s);
} }
} }
...@@ -416,30 +427,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, ...@@ -416,30 +427,41 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
return {shape_type, input_dims}; return {shape_type, input_dims};
} }
std::vector<std::size_t> dims; std::vector<shape::dynamic_dimension> dynamic_dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
std::transform(tensor_dims.begin(), std::transform(tensor_dims.begin(),
tensor_dims.end(), tensor_dims.end(),
std::back_inserter(dims), std::back_inserter(dynamic_dims),
[&](auto&& d) -> std::size_t { [&](auto&& d) -> shape::dynamic_dimension {
if(d.has_dim_value()) if(d.has_dim_value())
{ {
if(static_cast<int>(d.dim_value()) <= 0) if(static_cast<int>(d.dim_value()) <= 0)
{ {
return default_dim_value; return default_dyn_dim_value;
} }
return d.dim_value(); auto tmp = static_cast<std::size_t>(d.dim_value());
return {tmp, tmp, 0};
} }
else else
{ {
return default_dim_value; return default_dyn_dim_value;
} }
}); });
if(dims.empty()) if(dynamic_dims.empty())
{
return {shape_type}; return {shape_type};
}
if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
{
std::vector<std::size_t> dims;
std::transform(dynamic_dims.begin(),
dynamic_dims.end(),
std::back_inserter(dims),
[](auto d) { return d.max; });
return {shape_type, dims}; return {shape_type, dims};
}
return {shape_type, dynamic_dims};
} }
shape::type_t get_type(int dtype) shape::type_t get_type(int dtype)
......
...@@ -419,32 +419,17 @@ const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return im ...@@ -419,32 +419,17 @@ const std::vector<shape::dynamic_dimension>& shape::dyn_dims() const { return im
std::vector<std::size_t> shape::min_lens() const std::vector<std::size_t> shape::min_lens() const
{ {
if(not this->dynamic()) return this->dynamic() ? impl->min_lens() : this->lens();
{
return this->lens();
}
return impl->min_lens();
;
} }
std::vector<std::size_t> shape::max_lens() const std::vector<std::size_t> shape::max_lens() const
{ {
if(not this->dynamic()) return this->dynamic() ? impl->max_lens() : this->lens();
{
return this->lens();
}
return impl->max_lens();
;
} }
std::vector<std::size_t> shape::opt_lens() const std::vector<std::size_t> shape::opt_lens() const
{ {
if(not this->dynamic()) return this->dynamic() ? impl->opt_lens() : this->lens();
{
return this->lens();
}
return impl->opt_lens();
;
} }
bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; } bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max; }
......
...@@ -5410,7 +5410,59 @@ TEST_CASE(variable_batch_test) ...@@ -5410,7 +5410,59 @@ TEST_CASE(variable_batch_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(variable_batch_user_input_test) TEST_CASE(variable_batch_user_input_test1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 2, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.default_dyn_dim_value = {2, 5, 0};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0",
migraphx::shape{migraphx::shape::float_type,
{{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}}});
auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
mm->add_return({r});
migraphx::onnx_options options;
options.map_dyn_input_dims["0"] = {{2, 5, 0}, {3, 3, 0}, {16, 16, 0}, {16, 16, 0}};
auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(variable_batch_user_input_test4)
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment