Commit b74d3a8f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

responses to some PR comments, new parsing tests, and removed an IF block that...

responses to some PR comments, new parsing tests, and removed an IF block that blocked a needed check
parent 1c3be168
...@@ -299,14 +299,14 @@ struct parse_resize : op_parser<parse_resize> ...@@ -299,14 +299,14 @@ struct parse_resize : op_parser<parse_resize>
// input data shape info. Convert static lens to dynamic to simplify referencing them later // input data shape info. Convert static lens to dynamic to simplify referencing them later
auto in_s = args[0]->get_shape().to_dynamic(); auto in_s = args[0]->get_shape().to_dynamic();
if(in_s.ndim() < 2) if(args[0]->get_shape().dynamic() and in_s.ndim() < 2)
MIGRAPHX_THROW( MIGRAPHX_THROW(
"PARSE_" + opd.op_name + "PARSE_" + opd.op_name +
": requires 2 or more dimensions input, where first dimension is batch #"); ": requires 2 or more dimensions input, where first dimension is batch #");
std::vector<migraphx::shape::dynamic_dimension> in_dims = in_s.dyn_dims(); std::vector<migraphx::shape::dynamic_dimension> in_dims = in_s.dyn_dims();
// output shape is explicitly specified // output shape is explicitly specified
std::vector<size_t> out_lens(in_dims.size()); std::vector<size_t> out_lens(in_s.ndim());
// scale // scale
std::vector<double> vec_scale; std::vector<double> vec_scale;
...@@ -339,14 +339,14 @@ struct parse_resize : op_parser<parse_resize> ...@@ -339,14 +339,14 @@ struct parse_resize : op_parser<parse_resize>
out_lens.clear(); out_lens.clear();
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_dims.size()) if(out_lens.size() != in_s.ndim())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output rank does not match input rank"); ": specified output rank does not match input rank");
} }
// compute the scale in each dimension // compute the scale in each dimension
vec_scale.resize(in_dims.size()); vec_scale.resize(in_s.ndim());
std::transform(in_dims.begin(), std::transform(in_dims.begin(),
in_dims.end(), in_dims.end(),
...@@ -358,29 +358,25 @@ struct parse_resize : op_parser<parse_resize> ...@@ -358,29 +358,25 @@ struct parse_resize : op_parser<parse_resize>
else else
{ {
// scale input // scale input
if(lens[0] == in_dims.size()) auto arg_scale = arg->eval();
{ check_arg_empty(arg_scale,
auto arg_scale = arg->eval(); "PARSE_" + opd.op_name + ": dynamic input scale is not supported!");
check_arg_empty(arg_scale,
"PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified scale rank does not match input rank");
}
std::transform(in_dims.begin(), arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
in_dims.end(), if(in_dims.size() != vec_scale.size())
vec_scale.begin(), {
out_lens.begin(), MIGRAPHX_THROW("PARSE_" + opd.op_name +
[&](auto idx, auto scale) { ": specified scale rank does not match input rank");
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
} }
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
break; break;
} }
} }
......
...@@ -6669,6 +6669,69 @@ def resize_upsample_f_dyn_test(): ...@@ -6669,6 +6669,69 @@ def resize_upsample_f_dyn_test():
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err1_test():
scales = np.array([1.0, 1.0, 1.601, 1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, None, 3, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err2_test():
scales = np.array([1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err3_test():
scales = np.array([1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 3])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test() @onnx_test()
def resize_downsample_c_test(): def resize_downsample_c_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
......
...@@ -6300,6 +6300,36 @@ TEST_CASE(resize_downsample_f_dyn_test) ...@@ -6300,6 +6300,36 @@ TEST_CASE(resize_downsample_f_dyn_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(resize_dyn_err1_test)
{
// wrong dimension is dynamic
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err1_test.onnx", options); }));
}
TEST_CASE(resize_dyn_err2_test)
{
// 1-d dynamic input
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err2_test.onnx", options); }));
}
TEST_CASE(resize_dyn_err3_test)
{
// dimensions of input and scales don't match
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err3_test.onnx", options); }));
}
TEST_CASE(resize_downsample_linear_test) TEST_CASE(resize_downsample_linear_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -1791,7 +1791,6 @@ TEST_CASE(resize_downsample_f_dyn_test) ...@@ -1791,7 +1791,6 @@ TEST_CASE(resize_downsample_f_dyn_test)
{ {
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10}; options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options); auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
...@@ -1838,7 +1837,6 @@ TEST_CASE(resize_upsample_f_dyn_test) ...@@ -1838,7 +1837,6 @@ TEST_CASE(resize_upsample_f_dyn_test)
// resize with half_pixel and round_prefer_ceil, with scale > 1 // resize with half_pixel and round_prefer_ceil, with scale > 1
migraphx::onnx_options options; migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10}; options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto p = migraphx::parse_onnx("resize_upsample_f_dyn_test.onnx", options); auto p = migraphx::parse_onnx("resize_upsample_f_dyn_test.onnx", options);
p.compile(migraphx::make_target("ref")); p.compile(migraphx::make_target("ref"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment