Unverified Commit b8202d61 authored by Attila Dusnoki's avatar Attila Dusnoki Committed by GitHub
Browse files

Add scales attribute parse in upsample for older opset versions (#2336)

parent d7c8b66f
...@@ -181,41 +181,23 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr) ...@@ -181,41 +181,23 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
return nearest_mode; return nearest_mode;
} }
struct parse_resize : op_parser<parse_resize> static std::vector<double> get_scales(const onnx_parser::attribute_map& attr)
{ {
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; } std::vector<double> scales;
if(contains(attr, "scales"))
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// coord transform mode
std::string coord_trans_mode = get_coord_trans_mode(info.attributes);
// mode: only nearest and linear modes are supported for now
std::string mode = get_mode(info.attributes);
// nearest mode
std::string nearest_mode = get_nearest_mode(info.attributes);
// check exclude_outside, only support 0
if(contains(info.attributes, "exclude_outside") and
info.attributes.at("exclude_outside").i() == 1)
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!"); copy(attr.at("scales").floats(), std::back_inserter(scales));
} }
// input data shape info return scales;
auto in_s = args[0]->get_shape(); }
auto in_lens = in_s.lens();
// output shape is explicitly specified
std::vector<std::size_t> out_lens(in_lens.size());
// scale
std::vector<double> vec_scale;
static void parse_args(const std::vector<instruction_ref>& args,
const std::vector<size_t>& in_lens,
const std::string& op_name,
std::vector<double>& vec_scale,
std::vector<std::size_t>& out_lens)
{
for(const auto& arg : args) for(const auto& arg : args)
{ {
if(arg->name() == "undefined" or arg == args.front()) if(arg->name() == "undefined" or arg == args.front())
...@@ -236,12 +218,12 @@ struct parse_resize : op_parser<parse_resize> ...@@ -236,12 +218,12 @@ struct parse_resize : op_parser<parse_resize>
{ {
auto arg_out_s = arg->eval(); auto arg_out_s = arg->eval();
check_arg_empty(arg_out_s, check_arg_empty(arg_out_s,
"PARSE_" + opd.op_name + ": dynamic output size is not supported!"); "PARSE_" + op_name + ": dynamic output size is not supported!");
arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); }); arg_out_s.visit([&](const auto& ol) { out_lens.assign(ol.begin(), ol.end()); });
if(out_lens.size() != in_lens.size()) if(out_lens.size() != in_lens.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + op_name +
": specified output size does not match input size"); ": specified output size does not match input size");
} }
...@@ -261,25 +243,71 @@ struct parse_resize : op_parser<parse_resize> ...@@ -261,25 +243,71 @@ struct parse_resize : op_parser<parse_resize>
{ {
auto arg_scale = arg->eval(); auto arg_scale = arg->eval();
check_arg_empty(arg_scale, check_arg_empty(arg_scale,
"PARSE_" + opd.op_name + "PARSE_" + op_name + ": dynamic input scale is not supported!");
": dynamic input scale is not supported!");
arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); }); arg_scale.visit([&](const auto& v) { vec_scale.assign(v.begin(), v.end()); });
}
}
}
}
struct parse_resize : op_parser<parse_resize>
{
std::vector<op_desc> operators() const { return {{"Resize"}, {"Upsample"}}; }
instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/,
onnx_parser::node_info info,
std::vector<instruction_ref> args) const
{
// coord transform mode
std::string coord_trans_mode = get_coord_trans_mode(info.attributes);
// mode: only nearest and linear modes are supported for now
std::string mode = get_mode(info.attributes);
// nearest mode
std::string nearest_mode = get_nearest_mode(info.attributes);
// check exclude_outside, only support 0
if(contains(info.attributes, "exclude_outside") and
info.attributes.at("exclude_outside").i() == 1)
{
MIGRAPHX_THROW("PARSE_" + opd.op_name + ": exclude_outside 1 is not supported!");
}
// input data shape info
auto in_s = args[0]->get_shape();
auto in_lens = in_s.lens();
// output shape is explicitly specified
std::vector<std::size_t> out_lens(in_lens.size());
// scale
std::vector<double> vec_scale = get_scales(info.attributes);
// If `scales` was not an attribute, it must be an input
if(vec_scale.empty())
{
// Depending on the args, it *must* populate the `vec_scale`, and might populate
// `out_lens`
parse_args(args, in_lens, opd.op_name, vec_scale, out_lens);
}
if(in_lens.size() != vec_scale.size()) if(in_lens.size() != vec_scale.size())
{ {
MIGRAPHX_THROW("PARSE_" + opd.op_name + MIGRAPHX_THROW("PARSE_" + opd.op_name + ": ranks of input and scale are different!");
": ranks of input and scale are different!");
} }
std::transform(in_lens.begin(), // if the output was not calculated yet, we update it based on the scales
if(all_of(out_lens.cbegin(), out_lens.cend(), [](auto o) { return o == 0; }))
{
std::transform(
in_lens.begin(),
in_lens.end(), in_lens.end(),
vec_scale.begin(), vec_scale.begin(),
out_lens.begin(), out_lens.begin(),
[&](auto idx, auto scale) { [&](auto idx, auto scale) { return static_cast<std::size_t>(idx * scale); });
return static_cast<std::size_t>(idx * scale);
});
}
}
} }
shape out_s{in_s.type(), out_lens}; shape out_s{in_s.type(), out_lens};
...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize> ...@@ -288,7 +316,6 @@ struct parse_resize : op_parser<parse_resize>
// reshape input to one-dimension // reshape input to one-dimension
std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())}; std::vector<int64_t> rsp_lens = {static_cast<int64_t>(in_s.elements())};
args[0] = info.make_contiguous(args[0]);
auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]); auto rsp = info.add_instruction(make_op("reshape", {{"dims", rsp_lens}}), args[0]);
if(mode == "nearest") if(mode == "nearest")
......
...@@ -9031,6 +9031,20 @@ def upsample_test(): ...@@ -9031,6 +9031,20 @@ def upsample_test():
return ([node], [X], [Y], [scale_tensor]) return ([node], [X], [Y], [scale_tensor])
@onnx_test()
def upsample_ver7_test():
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 2, 2])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 4, 6])
node = onnx.helper.make_node('Upsample',
inputs=['X'],
outputs=['Y'],
mode='nearest',
scales=[1.0, 1.0, 2.0, 3.0])
return ([node], [X], [Y])
@onnx_test() @onnx_test()
def variable_batch_test(): def variable_batch_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
......
...@@ -6557,9 +6557,8 @@ TEST_CASE(resize_nonstd_input_test) ...@@ -6557,9 +6557,8 @@ TEST_CASE(resize_nonstd_input_test)
auto tx = auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined")); mm->add_instruction(migraphx::make_op("undefined"));
auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx_cont); auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li); auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r}); mm->add_return({r});
...@@ -8418,6 +8417,27 @@ TEST_CASE(upsample_test) ...@@ -8418,6 +8417,27 @@ TEST_CASE(upsample_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(upsample_ver7_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = mm->add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw_print_error) TEST_CASE(unknown_test_throw_print_error)
{ {
migraphx::onnx_options options; migraphx::onnx_options options;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment