Unverified commit c5eee1a3, authored by Charlie Lin and committed by GitHub

Convert a fully fixed map_dyn_input_dims value to a static shape when parsing ONNX (#1682)

Previously, a map_dyn_input_dims entry whose dynamic_dimensions are all fixed still produced a dynamic parameter shape; this change converts such an entry to a static shape. This is needed to allow static shapes to be set through map_dyn_input_dims, since map_input_dims cannot also be used when map_dyn_input_dims is set.
parent 42685803
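As a rough illustration (not part of the commit), here is a minimal sketch of how a caller could rely on the new behavior through the migraphx::onnx_options / migraphx::parse_onnx API that also appears in the tests below. The header path <migraphx/onnx.hpp>, the input name "input", the model file "model.onnx", and the dimension values are assumptions for the example, not taken from the patch.

#include <migraphx/onnx.hpp>

int main()
{
    migraphx::onnx_options options;
    // Every dynamic_dimension here is fixed (min == max), so with this change the
    // parser gives the parameter "input" the static shape {1, 3, 224, 224} rather
    // than a dynamic shape. "input" and "model.onnx" are placeholders.
    options.map_dyn_input_dims["input"] = {{1, 1}, {3, 3}, {224, 224}, {224, 224}};
    auto prog = migraphx::parse_onnx("model.onnx", options);
    return 0;
}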
@@ -39,6 +39,20 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
+static shape shape_from_dyn_dims(shape::type_t shape_type,
+                                 const std::vector<shape::dynamic_dimension>& dyn_dims)
+{
+    if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); }))
+    {
+        std::vector<std::size_t> dims;
+        std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) {
+            return d.max;
+        });
+        return {shape_type, dims};
+    }
+    return {shape_type, dyn_dims};
+}
+
 namespace onnx {
 
 static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node)
@@ -300,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini
         else if(map_dyn_input_dims.count(name) > 0)
         {
             shape::type_t shape_type = get_type(input.type().tensor_type().elem_type());
-            s = {shape_type, map_dyn_input_dims.at(name)};
+            s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name));
         }
         else
         {
@@ -503,16 +517,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t,
     {
         return {shape_type};
     }
-    if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); }))
-    {
-        std::vector<std::size_t> dims;
-        std::transform(dynamic_dims.begin(),
-                       dynamic_dims.end(),
-                       std::back_inserter(dims),
-                       [](auto d) { return d.max; });
-        return {shape_type, dims};
-    }
-    return {shape_type, dynamic_dims};
+    return shape_from_dyn_dims(shape_type, dynamic_dims);
 }
 
 shape::type_t get_type(int dtype)
...
@@ -4959,13 +4959,13 @@ TEST_CASE(reducemax_dyn_test)
     migraphx::program p;
     auto* mm = p.get_main_module();
     auto l0 = mm->add_parameter(
-        "x", migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}, {6, 6}}});
+        "x", migraphx::shape{migraphx::shape::float_type, {{3, 5}, {4, 4}, {5, 5}, {6, 6}}});
     auto r0 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0);
     auto r1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), r0);
     mm->add_return({r1});
 
     migraphx::onnx_options options;
-    options.map_dyn_input_dims["x"] = {{3, 3}, {4, 4}, {5, 5}, {6, 6}};
+    options.map_dyn_input_dims["x"] = {{3, 5}, {4, 4}, {5, 5}, {6, 6}};
     auto prog = migraphx::parse_onnx("reducemax_dyn_test.onnx", options);
 
     EXPECT(p == prog);
@@ -6953,6 +6953,23 @@ TEST_CASE(variable_batch_user_input_test6)
     EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); }));
 }
 
+TEST_CASE(variable_batch_user_input_test7)
+{
+    // if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static shape
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}});
+    auto r = mm->add_instruction(migraphx::make_op("identity"), l0);
+    mm->add_return({r});
+
+    migraphx::onnx_options options;
+    options.map_dyn_input_dims["0"] = {{2, 2, {2}}, {3, 3}, {16, 16}, {16, 16}};
+    auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options);
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(variable_batch_leq_zero_test)
 {
     migraphx::program p;
...