Commit b74d3a8f authored by Brian Pickrell's avatar Brian Pickrell
Browse files

responses to some PR comments, new parsing tests, and removed an IF block that...

responses to some PR comments, new parsing tests, and removed an IF block that blocked a needed check
parent 1c3be168
......@@ -299,14 +299,14 @@ struct parse_resize : op_parser<parse_resize>
// input data shape info. Convert static lens to dynamic to simplify referencing them later
auto in_s = args[0]->get_shape().to_dynamic();
if(in_s.ndim() < 2)
if(args[0]->get_shape().dynamic() and in_s.ndim() < 2)
MIGRAPHX_THROW(
"PARSE_" + opd.op_name +
": requires 2 or more dimensions input, where first dimension is batch #");
std::vector<migraphx::shape::dynamic_dimension> in_dims = in_s.dyn_dims();
// output shape is explicitly specified
std::vector<size_t> out_lens(in_dims.size());
std::vector<size_t> out_lens(in_s.ndim());
// scale
std::vector<double> vec_scale;
......@@ -339,14 +339,14 @@ struct parse_resize : op_parser<parse_resize>
out_lens.clear();
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_dims.size())
if(out_lens.size() != in_s.ndim())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output rank does not match input rank");
}
// compute the scale in each dimension
vec_scale.resize(in_dims.size());
vec_scale.resize(in_s.ndim());
std::transform(in_dims.begin(),
in_dims.end(),
......@@ -358,29 +358,25 @@ struct parse_resize : op_parser<parse_resize>
else
{
// scale input
if(lens[0] == in_dims.size())
{
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + opd.op_name +
": dynamic input scale is not supported!");
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified scale rank does not match input rank");
}
auto arg_scale = arg->eval();
check_arg_empty(arg_scale,
"PARSE_" + opd.op_name + ": dynamic input scale is not supported!");
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
arg_scale.visit([&](auto v) { vec_scale.assign(v.begin(), v.end()); });
if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified scale rank does not match input rank");
}
std::transform(in_dims.begin(),
in_dims.end(),
vec_scale.begin(),
out_lens.begin(),
[&](auto idx, auto scale) {
// inferred output size is floor(idx.max * scale)
return idx.max * scale;
});
break;
}
}
......
......@@ -6669,6 +6669,69 @@ def resize_upsample_f_dyn_test():
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err1_test():
scales = np.array([1.0, 1.0, 1.601, 1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, None, 3, 5])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err2_test():
scales = np.array([1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_dyn_err3_test():
scales = np.array([1.601], dtype=np.float32)
scale_tensor = helper.make_tensor(name='scales',
data_type=TensorProto.FLOAT,
dims=scales.shape,
vals=scales.flatten().astype(np.float32))
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 3])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [])
node = onnx.helper.make_node('Resize',
inputs=['X', '', 'scales'],
outputs=['Y'],
coordinate_transformation_mode='half_pixel',
mode='nearest',
nearest_mode='round_prefer_ceil')
return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def resize_downsample_c_test():
scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32)
......
......@@ -6300,6 +6300,36 @@ TEST_CASE(resize_downsample_f_dyn_test)
EXPECT(p == prog);
}
TEST_CASE(resize_dyn_err1_test)
{
// wrong dimension is dynamic
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err1_test.onnx", options); }));
}
TEST_CASE(resize_dyn_err2_test)
{
// 1-d dynamic input
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err2_test.onnx", options); }));
}
TEST_CASE(resize_dyn_err3_test)
{
// dimensions of input and scales don't match
migraphx::shape::dynamic_dimension dd{1, 10};
migraphx::onnx_options options;
options.default_dyn_dim_value = dd;
EXPECT(test::throws([&] { migraphx::parse_onnx("resize_dyn_err3_test.onnx", options); }));
}
TEST_CASE(resize_downsample_linear_test)
{
migraphx::program p;
......
......@@ -1791,7 +1791,6 @@ TEST_CASE(resize_downsample_f_dyn_test)
{
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto p = migraphx::parse_onnx("resize_downsample_f_dyn_test.onnx", options);
p.compile(migraphx::make_target("ref"));
......@@ -1838,7 +1837,6 @@ TEST_CASE(resize_upsample_f_dyn_test)
// resize with half_pixel and round_prefer_ceil, with scale > 1
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10};
options.use_dyn_output = true;
auto p = migraphx::parse_onnx("resize_upsample_f_dyn_test.onnx", options);
p.compile(migraphx::make_target("ref"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment