Commit 75dc1ac3 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

refactor a helper subroutine to placate tidy

parent 33649edd
......@@ -185,6 +185,97 @@ struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
// A helper for one case of parse().
// Dynamic batch: Only args[0] can have a dynamic shape, only the 0'th
// dimension--batch size--can be non-fixed, and the only resize mode allowed is "nearest"
instruction_ref dynamic_nearest_parse(const std::vector<size_t>& out_lens,
const std::vector<double>& vec_scale,
const op_desc& opd,
onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
// coord transform mode
std::string coord_trans_mode = get_coord_trans_mode(info.attributes);
// mode: only nearest and linear modes are supported for now
std::string mode = get_mode(info.attributes);
// rounding option when using "nearest"
std::string nearest_mode = get_nearest_mode(info.attributes);
if(mode == "nearest")
{
auto some_dims = args[0]->get_shape().dyn_dims();
bool mostly_fixed =
std::all_of(some_dims.begin() + 1,
some_dims.end(),
[](shape::dynamic_dimension dd) { return dd.is_fixed(); });
if(not mostly_fixed)
MIGRAPHX_THROW("PARSE_" + opd.op_name +
": dynamic shape inputs other than batch size are not supported");
// Get static dimension set and
// Drop the 0'th dimension,
auto fixed_dims = args[0]->get_shape().to_static(1).lens();
fixed_dims.erase(fixed_dims.begin());
// dimensions of the (scaled) output, also with the 0'th dimension dropped
auto fixed_out_lens = out_lens;
fixed_out_lens.erase(fixed_out_lens.begin());
// create a shape with the scaled lens and no batch dimension
migraphx::shape static_out_shape(args[0]->get_shape().type(), fixed_out_lens);
// map out_idx to in_idx
auto idx_op = get_original_idx_op(coord_trans_mode);
auto nearest_op = get_nearest_op(nearest_mode);
// For each element of static_out_shape, find the matching location of input shape.
// The indexes we find will be an argument to the gather op.
shape_for_each(static_out_shape, [&](const auto& out_idx_v, size_t) {
std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < fixed_dims.size(); ++ii)
{
// Convert this index by scaling.
auto idx_val =
idx_op(fixed_dims[ii], fixed_out_lens[ii], out_idx_v[ii], vec_scale[ii]);
// round the scaled value to an int index
in_idx[ii] = nearest_op(fixed_dims[ii], idx_val);
}
});
instruction_ref gather_ins{args[0]};
// for each static dimension
for(auto ii = 0; ii < fixed_dims.size(); ++ii)
{
std::vector<size_t> in_idx(fixed_out_lens[ii]);
// for range of this dimension's size in output
for(auto len : range(fixed_out_lens[ii]))
{
// Convert this index by scaling.
auto idx_val =
idx_op(fixed_dims[ii], fixed_out_lens[ii], len, vec_scale[ii + 1]);
// round the scaled value to an index
in_idx[len] = nearest_op(fixed_dims[ii], idx_val);
// Put the value into index vector
}
// Create a 1D shape literal
auto index_litA = info.add_literal(literal(
migraphx::shape(migraphx::shape::int64_type, {fixed_out_lens[ii]}), in_idx));
// add a "gather" instruction
gather_ins = info.add_instruction(
make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_litA);
}
return gather_ins;
}
else
{
MIGRAPHX_THROW("PARSE_RESIZE: only nearest_mode supports dynamic batch size input");
}
}
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
......@@ -303,80 +394,7 @@ struct parse_resize : op_parser<parse_resize>
// dimension--batch size--can be non-fixed, and the only resize mode allowed is "nearest"
if(args[0]->get_shape().dynamic())
{
if(mode == "nearest")
{
auto some_dims = args[0]->get_shape().dyn_dims();
bool mostly_fixed =
std::all_of(some_dims.begin() + 1,
some_dims.end(),
[](shape::dynamic_dimension dd) { return dd.is_fixed(); });
if(not mostly_fixed)
MIGRAPHX_THROW(
"PARSE_" + opd.op_name +
": dynamic shape inputs other than batch size are not supported");
// Get static dimension set and
// Drop the 0'th dimension,
auto fixed_dims = args[0]->get_shape().to_static(1).lens();
fixed_dims.erase(fixed_dims.begin());
// dimensions of the (scaled) output, also with the 0'th dimension dropped
auto fixed_out_lens = out_lens;
fixed_out_lens.erase(fixed_out_lens.begin());
// create a shape with the scaled lens and no batch dimension
migraphx::shape static_out_shape(args[0]->get_shape().type(), fixed_out_lens);
// map out_idx to in_idx
auto idx_op = get_original_idx_op(coord_trans_mode);
auto nearest_op = get_nearest_op(nearest_mode);
// For each element of static_out_shape, find the matching location of input shape.
// The indexes we find will be an argument to the gather op.
shape_for_each(static_out_shape, [&](const auto& out_idx_v, size_t) {
std::vector<size_t> in_idx(out_idx_v.size());
for(auto ii = 0; ii < fixed_dims.size(); ++ii)
{
// Convert this index by scaling.
auto idx_val = idx_op(
fixed_dims[ii], fixed_out_lens[ii], out_idx_v[ii], vec_scale[ii]);
// round the scaled value to an int index
in_idx[ii] = nearest_op(fixed_dims[ii], idx_val);
}
});
instruction_ref gather_ins{args[0]};
// for each static dimension
for(auto ii = 0; ii < fixed_dims.size(); ++ii)
{
std::vector<size_t> in_idx(fixed_out_lens[ii]);
// for range of this dimension's size in output
for(auto len : range(fixed_out_lens[ii]))
{
// Convert this index by scaling.
auto idx_val =
idx_op(fixed_dims[ii], fixed_out_lens[ii], len, vec_scale[ii + 1]);
// round the scaled value to an index
in_idx[len] = nearest_op(fixed_dims[ii], idx_val);
// Put the value into index vector
}
// Create a 1D shape literal
auto index_litA = info.add_literal(
literal(migraphx::shape(migraphx::shape::int64_type, {fixed_out_lens[ii]}),
in_idx));
// add a "gather" instruction
gather_ins = info.add_instruction(
make_op("gather", {{"axis", 1 + ii}}), gather_ins, index_litA);
}
return gather_ins;
}
else
{
MIGRAPHX_THROW("PARSE_RESIZE: only nearest_mode supports dynamic batch size input");
}
return dynamic_nearest_parse(out_lens, vec_scale, opd, info, args);
}
else
{
......
......@@ -6287,9 +6287,9 @@ TEST_CASE(resize_downsample_f_dyn_test)
std::vector<int> ind3 = {0, 1, 3, 4, 6};
auto li3 = mm->add_literal(migraphx::literal(si3, ind3));
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), inx, li1);
r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 2}}), r, li2);
r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 3}}), r, li3);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 1}}), inx, li1);
r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 2}}), r, li2);
r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 3}}), r, li3);
mm->add_return({r});
migraphx::onnx_options options;
......
......@@ -1757,7 +1757,6 @@ TEST_CASE(resize_downsample_f_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_outsize_test)
{
// resize using output_size input, rather than scales
......@@ -1770,7 +1769,7 @@ TEST_CASE(resize_outsize_test)
migraphx::shape sy{migraphx::shape::float_type, {1, 1, 4, 6}};
std::vector<float> dy(sx.elements(), 0);
migraphx::parameter_map pp;
pp["X"] = migraphx::argument(sx, dx.data());
pp["Y"] = migraphx::argument(sx, dy.data());
......@@ -1789,7 +1788,6 @@ TEST_CASE(resize_outsize_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_dyn_test)
{
migraphx::onnx_options options;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment