"script/clang-format-overwrite.sh" did not exist on "6260ced2f3a4d9a2a832563905135c01ba72b56b"
Unverified Commit 42230bbd authored by Zakor Gyula's avatar Zakor Gyula Committed by GitHub
Browse files

Limit max loop iterations (#2361)

parent 90b6b323
......@@ -326,7 +326,7 @@ op
parse_onnx
----------
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10)
.. py:function:: parse_onnx(filename, default_dim_value=1, map_input_dims={}, skip_unknown_operators=false, print_program_on_error=false, max_loop_iterations=10, limit_max_iterations=65535)
Load and parse an onnx file.
......@@ -337,7 +337,8 @@ parse_onnx
:param list[dynamic_dimension] map_dyn_input_dims: Explicitly specify the dynamic_dimensions of an input.
:param str skip_unknown_operators: Continue parsing onnx file if an unknown operator is found.
:param str print_program_on_error: Print program if an error occurs.
:param int max_loop_iterations: Maximum iteration number for the loop operator.
:param int max_loop_iterations: Maximum iteration number for the loop operator if trip count is not set.
:param int limit_max_iterations: Maximum iteration limit for the loop operator.
:rtype: program
parse_tf
......
......@@ -164,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value;
}
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......@@ -1904,6 +1909,17 @@ migraphx_onnx_options_set_default_loop_iterations(migraphx_onnx_options_t onnx_o
return api_error_result;
}
extern "C" migraphx_status
migraphx_onnx_options_set_limit_loop_iterations(migraphx_onnx_options_t onnx_options, int64_t value)
{
auto api_error_result = migraphx::try_([&] {
if(onnx_options == nullptr)
MIGRAPHX_THROW(migraphx_status_bad_param, "Bad parameter onnx_options: Null pointer");
migraphx::set_limit_loop_iterations((onnx_options->object), (value));
});
return api_error_result;
}
extern "C" migraphx_status migraphx_file_options_destroy(migraphx_file_options_t file_options)
{
auto api_error_result = migraphx::try_([&] { destroy((file_options)); });
......
......@@ -514,6 +514,9 @@ MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_dyn_dim_valu
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_default_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status migraphx_onnx_options_set_limit_loop_iterations(
migraphx_onnx_options_t onnx_options, int64_t value);
MIGRAPHX_C_EXPORT migraphx_status
migraphx_file_options_destroy(migraphx_file_options_t file_options);
......
......@@ -1321,6 +1321,12 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options)
{
call(&migraphx_onnx_options_set_default_loop_iterations, this->get_handle_ptr(), value);
}
/// Set max iteration limit for the loop operator
void set_limit_loop_iterations(int64_t value)
{
call(&migraphx_onnx_options_set_limit_loop_iterations, this->get_handle_ptr(), value);
}
};
/// Parse an onnx file into a migraphx program
......
......@@ -349,6 +349,11 @@ def onnx_options(h):
api.params(value='int64_t'),
invoke='migraphx::set_default_loop_iterations($@)',
)
h.method(
'set_limit_loop_iterations',
api.params(value='int64_t'),
invoke='migraphx::set_limit_loop_iterations($@)',
)
@auto_handle()
......
......@@ -48,8 +48,12 @@ struct onnx_options
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
/// Max iter num for the loop operator
/// Max iter num for the loop operator if trip count is not set
int64_t max_loop_iterations = 10;
/// Max iter limit for the loop operator.
/// Since loop will become a tensor of max iter size a huge number can cause overflow during
/// shape computations.
int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
/// Use dynamic output for operators when available
bool use_dyn_output = false;
};
......
......@@ -97,10 +97,11 @@ struct onnx_parser
shape::dynamic_dimension default_dyn_dim_value = {1, 1};
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
std::unordered_map<std::string, std::vector<shape::dynamic_dimension>> map_dyn_input_dims;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t opset_version = 13;
bool use_dyn_output = false;
bool skip_unknown_operators = false;
int64_t max_loop_iterations = 10;
int64_t limit_max_iterations = std::numeric_limits<uint16_t>::max();
int64_t opset_version = 13;
std::unordered_map<std::string, op_func> ops;
......
......@@ -67,6 +67,7 @@ program parse_onnx_from(const onnx_options& options, Ts&&... xs)
}
parser.skip_unknown_operators = options.skip_unknown_operators;
parser.max_loop_iterations = options.max_loop_iterations;
parser.limit_max_iterations = options.limit_max_iterations;
parser.use_dyn_output = options.use_dyn_output;
if(options.print_program_on_error)
......
......@@ -58,6 +58,16 @@ struct parse_loop : op_parser<parse_loop>
}
}
// cap max_iter because loop uses static shapes with max_iter size and huge numbers
// here can cause overflow
if(max_iterations > parser.limit_max_iterations)
{
std::cerr << "WARNING: PARSE_LOOP max_iterations exceeds the maximum loop "
"iterations limit, it will be changed from "
<< max_iterations << " to " << parser.limit_max_iterations << ".\n";
max_iterations = parser.limit_max_iterations;
}
// condition input is empty
if(args.at(1)->name() == "undefined")
{
......
......@@ -472,7 +472,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
map_dyn_input_dims,
bool skip_unknown_operators,
bool print_program_on_error,
int64_t max_loop_iterations) {
int64_t max_loop_iterations,
int64_t limit_max_iterations) {
migraphx::onnx_options options;
options.default_dim_value = default_dim_value;
options.default_dyn_dim_value = default_dyn_dim_value;
......@@ -481,6 +482,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
options.skip_unknown_operators = skip_unknown_operators;
options.print_program_on_error = print_program_on_error;
options.max_loop_iterations = max_loop_iterations;
options.limit_max_iterations = limit_max_iterations;
return migraphx::parse_onnx(filename, options);
},
"Parse onnx file",
......@@ -492,7 +494,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
std::unordered_map<std::string, std::vector<migraphx::shape::dynamic_dimension>>(),
py::arg("skip_unknown_operators") = false,
py::arg("print_program_on_error") = false,
py::arg("max_loop_iterations") = 10);
py::arg("max_loop_iterations") = 10,
py::arg("limit_max_iterations") = std::numeric_limits<uint16_t>::max());
m.def(
"parse_onnx_buffer",
......
......@@ -198,4 +198,29 @@ TEST_CASE(set_loop_default_iter_num)
EXPECT(out_shapes[1].lengths() == out_lens1);
}
TEST_CASE(set_loop_limit_iterations)
{
migraphx::onnx_options option;
option.set_default_loop_iterations(15);
option.set_limit_loop_iterations(10);
auto p = migraphx::parse_onnx("loop_default_test.onnx", option);
auto out_shapes = p.get_output_shapes();
std::vector<std::size_t> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<std::size_t> out_lens1 = {10, 1};
EXPECT(out_shapes[1].lengths() == out_lens1);
}
TEST_CASE(set_loop_limit_iterations2)
{
migraphx::onnx_options option;
option.set_limit_loop_iterations(10);
auto p = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", option);
auto out_shapes = p.get_output_shapes();
std::vector<std::size_t> out_lens0 = {1};
EXPECT(out_shapes[0].lengths() == out_lens0);
std::vector<std::size_t> out_lens1 = {10, 1};
EXPECT(out_shapes[1].lengths() == out_lens1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -317,4 +317,59 @@ TEST_CASE(loop_test)
}
}
TEST_CASE(loop_test_limit_max_iter)
{
auto run_prog = [&](int64_t limit_max_iterations) {
migraphx::onnx_options parse_options;
parse_options.set_limit_loop_iterations(limit_max_iterations);
auto p = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", parse_options);
auto shapes_before = p.get_output_shapes();
migraphx::compile_options options;
options.set_offload_copy();
p.compile(migraphx::target("gpu"), options);
auto shapes_after = p.get_output_shapes();
CHECK(shapes_before.size() == 2);
CHECK(bool{shapes_before.front() == shapes_after.front()});
migraphx::program_parameters pp;
auto param_shapes = p.get_parameter_shapes();
auto aas = param_shapes["a"];
std::vector<float> xd = {1.0f};
pp.add("a", migraphx::argument(aas, xd.data()));
auto bbs = param_shapes["b"];
std::vector<float> yd = {2.0};
pp.add("b", migraphx::argument(bbs, yd.data()));
auto cs = param_shapes["keep_going_cond"];
bool cond = true;
pp.add("keep_going_cond", migraphx::argument(cs, &cond));
auto outputs = p.eval(pp);
auto output = outputs[0];
std::vector<std::vector<float>> ret;
ret.push_back(output.as_vector<float>());
output = outputs[1];
ret.push_back(output.as_vector<float>());
return ret;
};
{
auto result_vector = run_prog(5);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
{
auto result_vector = run_prog(20);
std::vector<float> gold0 = {2.0f};
EXPECT(result_vector.at(0) == gold0);
std::vector<float> gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
EXPECT(result_vector.at(1) == gold1);
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -4334,6 +4334,50 @@ def loop_test():
return ([node], [iter, cond, a, b], [b_loop, uout])
@onnx_test()
def loop_test_implicit_tripcnt():
body = helper.make_graph([
helper.make_node("Add", ["a", "b_in"], ["my_local"]),
helper.make_node("Sub", ["a", "b_in"], ["a_sub_b_in"]),
helper.make_node("Greater", ["my_local", "a_sub_b_in"],
["keep_going"]),
helper.make_node("Add", ["a_sub_b_in", "a_sub_b_in"],
["user_defined_vals"]),
], "body", [
helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]),
helper.make_tensor_value_info('keep_going_inp', TensorProto.BOOL, [1]),
helper.make_tensor_value_info('b_in', TensorProto.FLOAT, [1])
], [
helper.make_tensor_value_info('keep_going', TensorProto.BOOL, [1]),
helper.make_tensor_value_info('a_sub_b_in', TensorProto.FLOAT, [1]),
helper.make_tensor_value_info('my_local', TensorProto.FLOAT, [1]),
helper.make_tensor_value_info('user_defined_vals', TensorProto.FLOAT,
[1]),
])
iter = helper.make_tensor(name='max_trip_count',
data_type=TensorProto.INT64,
dims=[1],
vals=[15])
node = helper.make_node(
"Loop",
inputs=["max_trip_count", "keep_going_cond", "b"],
outputs=["b_loop", "my_local_loop", "user_defined_vals_loop"],
body=body)
a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [1])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [1])
cond = helper.make_tensor_value_info('keep_going_cond', TensorProto.BOOL,
[1])
b_loop = helper.make_tensor_value_info('b_loop', TensorProto.FLOAT, [1])
uout = helper.make_tensor_value_info('user_defined_vals_loop',
TensorProto.FLOAT, [2, 1])
return ([node], [cond, a, b], [b_loop, uout], [iter])
@onnx_test()
def lpnormalization_axis_error_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 3])
......
 loop_test_implicit_tripcnt:

max_trip_count
keep_going_cond
bb_loop my_local_loopuser_defined_vals_loop"Loop*
body2

a
b_inmy_local"Add

a
b_in
a_sub_b_in"Sub
+
my_local
a_sub_b_in
keep_going"Greater
0
a_sub_b_in
a_sub_b_inuser_defined_vals"AddbodyZ
iteration_num

Z
keep_going_inp
 
Z
b_in

b
keep_going
 
b
a_sub_b_in

b
my_local

b
user_defined_vals

loop_test_implicit_tripcnt*:Bmax_trip_countZ
keep_going_cond
 
Z
a

Z
b

b
b_loop

b(
user_defined_vals_loop


B
\ No newline at end of file
......@@ -164,6 +164,11 @@ void set_default_loop_iterations(onnx_options& options, int64_t value)
options.max_loop_iterations = value;
}
void set_limit_loop_iterations(onnx_options& options, int64_t value)
{
options.limit_max_iterations = value;
}
void set_nhwc(tf_options& options, bool is_nhwc) { options.is_nhwc = is_nhwc; }
void set_default_dim_value(tf_options& options, size_t value) { options.batch_size = value; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment