Commit 154e24f2 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

add resize_outsize_test and fix parsing of out_lens

parent c6ddbc45
......@@ -207,9 +207,11 @@ struct parse_resize : op_parser<parse_resize>
}
// input data shape info. Convert static lens to dynamic to simplify referencing them later
auto in_s = args[0]->get_shape().to_dynamic();
auto in_s = args[0]->get_shape().to_dynamic();
if(in_s.ndim() < 2)
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": requires 2 or more dimensions input, where first dimension is batch #");
MIGRAPHX_THROW(
"PARSE_" + opd.op_name +
": requires 2 or more dimensions input, where first dimension is batch #");
std::vector<migraphx::shape::dynamic_dimension> in_dims = in_s.dyn_dims();
// output shape is explicitly specified
......@@ -218,7 +220,7 @@ struct parse_resize : op_parser<parse_resize>
// scale
std::vector<double> vec_scale;
// Infer either output size or scale, depending on input type
// Look at inputs and infer either output size or scale, depending on input type
for(const auto& arg : args)
{
if(arg->name() == "undefined" or arg == args.front())
......@@ -226,13 +228,13 @@ struct parse_resize : op_parser<parse_resize>
continue;
}
// this is just developer code, figure out real requirement
if(arg != args[0] and arg->get_shape().dynamic())
{
MIGRAPHX_THROW("parse_resize: no other dynamic shapes allowed");
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": no dynamic input shapes allowed except the first one");
}
// skipped empty input
// skip any empty inputs
auto lens = arg->get_shape().lens();
if(lens.empty())
{
......@@ -240,31 +242,21 @@ struct parse_resize : op_parser<parse_resize>
}
auto type = arg->get_shape().type();
// output size
// This input is inferred to mean output size if type == int64_type; otherwise
// read it as the scales
if(type == shape::int64_type)
{
auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!");
// reallocate a vector and copy the values to it. All dimensions except batch, even
// if originally dynamic, are required to be fixed so we can refer to their max
// value WLOG.
arg_out_s.visit([&](auto ol) {
// todo: assign doesn't work with dynamic shapes
auto ols = ol.get_shape().to_dynamic();
for(auto it = ols.dyn_dims().begin(); it != ols.dyn_dims().end(); it++)
{
out_lens.push_back(it->max);
}
// out_lens.assign(ol.begin(), ol.end());
});
out_lens.clear();
arg_out_s.visit([&](auto ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_dims.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": specified output size does not match input size");
": specified output rank does not match input rank");
}
// compute the scale in each dimension
......@@ -275,6 +267,7 @@ struct parse_resize : op_parser<parse_resize>
out_lens.begin(),
vec_scale.begin(),
[](auto iss, auto oss) { return double(1.0 * oss / iss.max); });
break;
}
else
{
......@@ -290,7 +283,7 @@ struct parse_resize : op_parser<parse_resize>
if(in_dims.size() != vec_scale.size())
{
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": ranks of input and scale are different!");
": specified scale rank does not match input rank");
}
std::transform(in_dims.begin(),
......@@ -302,6 +295,7 @@ struct parse_resize : op_parser<parse_resize>
return idx.max * scale;
});
}
break;
}
}
......@@ -323,11 +317,9 @@ struct parse_resize : op_parser<parse_resize>
"PARSE_" + opd.op_name +
": dynamic shape inputs other than batch size are not supported");
// TODO: Add support for channel dimension
// take max_lens() to get static dimension set
// Get static dimension set and
// Drop the 0'th dimension,
auto fixed_dims = args[0]->get_shape().max_lens();
auto fixed_dims = args[0]->get_shape().to_static(1).lens();
fixed_dims.erase(fixed_dims.begin());
// dimensions of the (scaled) output, also with the 0'th dimension dropped
auto fixed_out_lens = out_lens;
......@@ -336,19 +328,13 @@ struct parse_resize : op_parser<parse_resize>
// create a shape with the scaled lens and no batch dimension
migraphx::shape static_out_shape(args[0]->get_shape().type(), fixed_out_lens);
// size_t out_elements = std::accumulate(fixed_out_lens.begin(),
// fixed_out_lens.end(),
// std::size_t{1},
// std::multiplies<>());
// std::vector<int> ind(out_elements);
// map out_idx to in_idx
auto idx_op = get_original_idx_op(coord_trans_mode);
auto nearest_op = get_nearest_op(nearest_mode);
// For each element of static_out_shape, find the matching location of input shape.
// The indexes we find will be an argument to the gather op.
shape_for_each(static_out_shape, [&](const auto& out_idx_v, size_t ) {
shape_for_each(static_out_shape, [&](const auto& out_idx_v, size_t) {
std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < fixed_dims.size(); ++ii)
{
......@@ -369,18 +355,21 @@ struct parse_resize : op_parser<parse_resize>
for(auto len : range(fixed_out_lens[ii]))
{
// Convert this index by scaling.
auto idx_val = idx_op(fixed_dims[ii], fixed_out_lens[ii], len, vec_scale[ii+1]);
auto idx_val =
idx_op(fixed_dims[ii], fixed_out_lens[ii], len, vec_scale[ii + 1]);
// round the scaled value to an index
in_idx[len] = nearest_op(fixed_dims[ii], idx_val);
in_idx[len] = nearest_op(fixed_dims[ii], idx_val);
// Put the value into index vector
}
// Create a 1D shape literal
auto index_litA = info.add_literal(literal(migraphx::shape(migraphx::shape::int64_type,
{fixed_out_lens[ii]}), in_idx));
auto index_litA = info.add_literal(
literal(migraphx::shape(migraphx::shape::int64_type, {fixed_out_lens[ii]}),
in_idx));
// add a "gather" instruction
gather_ins = info.add_instruction(make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_litA);
// add a "gather" instruction
gather_ins = info.add_instruction(
make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_litA);
}
return gather_ins;
}
......@@ -402,8 +391,6 @@ struct parse_resize : op_parser<parse_resize>
auto idx_op = get_original_idx_op(coord_trans_mode);
// reshape input to one-dimension
// TODO: We did this in multi dimensions in the dynamic case. Can we do
// the same here?
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
......@@ -425,11 +412,12 @@ struct parse_resize : op_parser<parse_resize>
ind[out_idx] = static_cast<int64_t>(in_s.index(in_idx));
});
// get the number of dimensions
// std::size_t n_dim = out_lens.size();
// std::vector<std::vector<std::size_t>> vv_ind(2, std::vector<std::size_t>(out_elements));
// std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
// std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
// get the number of dimensions
// std::size_t n_dim = out_lens.size();
// std::vector<std::vector<std::size_t>> vv_ind(2,
// std::vector<std::size_t>(out_elements));
// std::vector<std::vector<std::vector<std::size_t>>> vvv_ind(n_dim, vv_ind);
// std::vector<std::vector<float>> delta(n_dim, std::vector<float>(out_elements));
shape ind_s{shape::int32_type, out_lens};
auto ins_ind = info.add_literal(literal(ind_s, ind));
......
......@@ -1757,6 +1757,39 @@ TEST_CASE(resize_downsample_f_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_outsize_test)
{
// resize using output_size input, rather than scales
migraphx::program p = migraphx::parse_onnx("resize_outsize_test.onnx");
p.compile(migraphx::make_target("ref"));
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
std::vector<float> dx(sx.elements());
std::iota(dx.begin(), dx.end(), 0.1f);
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}};
std::vector<float> dy(sx.elements(), 0);
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
pp["Y"] = migraphx::argument(sx, dy.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold = {0.1f, 0.1f, 1.1f, 1.1f, 1.1f, 1.1f,
2.1f, 2.1f, 3.1f, 3.1f, 3.1f, 3.1f,
2.1f, 2.1f, 3.1f, 3.1f, 3.1f, 3.1f,
2.1f, 2.1f, 3.1f, 3.1f, 3.1f, 3.1f};
// clang-format on
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_dyn_test)
{
migraphx::onnx_options options;
......@@ -1776,7 +1809,7 @@ TEST_CASE(resize_downsample_f_dyn_test)
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold = {
0.1f, 1.1f, 3.1f, 4.1f, 6.1f,
......@@ -1787,7 +1820,8 @@ TEST_CASE(resize_downsample_f_dyn_test)
72.1f, 73.1f, 75.1f, 76.1f, 78.1f};
// clang-format on
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector, migraphx::verify::expected{gold}));
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector,
migraphx::verify::expected{gold}));
}
TEST_CASE(resize_upsample_f_dyn_test)
......@@ -1806,11 +1840,11 @@ TEST_CASE(resize_upsample_f_dyn_test)
std::iota(dx.begin(), dx.end(), 0.1f);
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
pp["X"] = migraphx::argument(sx, dx.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// clang-format off
std::vector<float> gold = {
0.1f, 0.1f, 1.1f, 2.1f, 2.1f, 3.1f, 4.1f, 4.1f,
......@@ -1825,7 +1859,8 @@ TEST_CASE(resize_upsample_f_dyn_test)
// Using verify_range_with_tolerance() because floating-point
// rounding errorswere observed.
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector, migraphx::verify::expected{gold}));
EXPECT(migraphx::verify::verify_range_with_tolerance(result_vector,
migraphx::verify::expected{gold}));
}
TEST_CASE(resize_upsample_linear_ac_test)
......
......@@ -277,14 +277,14 @@ TEST_CASE(convolution_backwards_dyn_batch2)
params["x"] = migraphx::argument(input_fixed_shape, x_data.data());
auto result = p.eval(params).back();
//clang-format off
// clang-format off
std::vector<float> gold{12., 0., 21., 0., 27., 0., 33., 0., 24., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 33., 0., 54., 0., 63., 0., 72., 0., 51., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 63., 0., 99., 0., 108., 0.,
117., 0., 81., 0., 0., 0., 0., 0., 0., 0., 0., 0., 93., 0.,
144., 0., 153., 0., 162., 0., 111., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 72., 0., 111., 0., 117., 0., 123., 0., 84.};
//clang-format on
// clang-format on
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment